use crate::backends::{GenerateParams, GeneratedToken, LlmBackend, Tokenizer};
use crate::error::{Result, RuvLLMError};
use parking_lot::RwLock;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Tuning knobs for speculative decoding (draft-then-verify generation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeculativeConfig {
    /// Number of tokens the draft model proposes per round.
    pub lookahead: usize,
    /// Threshold for accepting a draft token.
    /// NOTE(review): not referenced in this file — confirm it is used elsewhere.
    pub acceptance_threshold: f32,
    /// Sampling temperature for the draft model (0.0 = greedy drafting).
    pub draft_temperature: f32,
    /// Enables tree-based speculation instead of a single linear draft.
    pub tree_speculation: bool,
    /// Maximum depth of the speculation tree.
    pub max_tree_depth: usize,
    /// Intended children per node in the speculation tree.
    pub tree_branching_factor: usize,
    /// Nucleus-sampling cutoff used by the draft model.
    pub draft_top_p: f32,
    /// Lower bound on the acceptance ratio.
    /// NOTE(review): not referenced in this file — confirm it is used elsewhere.
    pub min_acceptance_ratio: f32,
    /// When true, the lookahead window grows/shrinks with the acceptance rate.
    pub adaptive_lookahead: bool,
    /// Smallest lookahead allowed by the adaptive controller.
    pub min_lookahead: usize,
    /// Largest lookahead allowed by the adaptive controller.
    pub max_lookahead: usize,
}
impl Default for SpeculativeConfig {
fn default() -> Self {
Self {
lookahead: 4,
acceptance_threshold: 0.5,
draft_temperature: 0.0,
tree_speculation: false,
max_tree_depth: 3,
tree_branching_factor: 2,
draft_top_p: 1.0,
min_acceptance_ratio: 0.1,
adaptive_lookahead: true,
min_lookahead: 2,
max_lookahead: 8,
}
}
}
/// Plain-data snapshot of speculative-decoding performance counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SpeculativeStats {
    /// Total tokens proposed by the draft model.
    pub draft_tokens: usize,
    /// Draft tokens confirmed by the main model.
    pub accepted_tokens: usize,
    /// `accepted_tokens / draft_tokens`.
    pub acceptance_rate: f32,
    /// Estimated speedup vs. one-token-per-pass decoding
    /// (mirrors `avg_tokens_per_main_pass`).
    pub speedup: f32,
    /// Verification passes run on the main model.
    pub main_forward_passes: usize,
    /// Forward passes run on the draft model.
    pub draft_forward_passes: usize,
    /// Tokens emitted per main-model pass (accepted drafts + bonus token).
    pub avg_tokens_per_main_pass: f32,
    /// Wall-clock time spent in speculation rounds, in milliseconds.
    pub total_speculation_time_ms: f64,
    /// Tokens produced overall (accepted drafts plus verifier tokens).
    pub total_tokens_generated: usize,
}
impl SpeculativeStats {
    /// Creates an all-zero statistics record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Refreshes `acceptance_rate` from the raw counters.
    /// Leaves the previous value untouched when no draft tokens were seen.
    pub fn update_acceptance_rate(&mut self) {
        if self.draft_tokens == 0 {
            return;
        }
        self.acceptance_rate = self.accepted_tokens as f32 / self.draft_tokens as f32;
    }

    /// Refreshes `avg_tokens_per_main_pass` and mirrors it into `speedup`
    /// (without speculation, each main pass would yield exactly one token).
    pub fn calculate_speedup(&mut self) {
        if self.main_forward_passes == 0 {
            return;
        }
        self.avg_tokens_per_main_pass =
            self.total_tokens_generated as f32 / self.main_forward_passes as f32;
        self.speedup = self.avg_tokens_per_main_pass;
    }

    /// Folds one speculation round into the totals and refreshes the derived
    /// metrics. Each round produces `accepted_count + 1` tokens: the accepted
    /// draft prefix plus the verifier's own next token.
    pub fn record_round(
        &mut self,
        draft_count: usize,
        accepted_count: usize,
        speculation_time_ms: f64,
    ) {
        self.draft_tokens += draft_count;
        self.draft_forward_passes += draft_count;
        self.accepted_tokens += accepted_count;
        self.main_forward_passes += 1;
        self.total_tokens_generated += accepted_count + 1;
        self.total_speculation_time_ms += speculation_time_ms;
        self.update_acceptance_rate();
        self.calculate_speedup();
    }

    /// Zeroes every counter and derived metric.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// Thread-safe counterpart of [`SpeculativeStats`], updated lock-free with
/// relaxed atomics. Convert to a plain snapshot via `snapshot()`.
pub struct AtomicSpeculativeStats {
    // Raw counters; see `SpeculativeStats` for field meanings.
    draft_tokens: AtomicUsize,
    accepted_tokens: AtomicUsize,
    main_forward_passes: AtomicUsize,
    draft_forward_passes: AtomicUsize,
    total_tokens_generated: AtomicUsize,
    // Stored in nanoseconds; converted to milliseconds in `snapshot()`.
    total_speculation_time_ns: AtomicU64,
}
impl Default for AtomicSpeculativeStats {
    /// Equivalent to [`AtomicSpeculativeStats::new`]: all counters start at zero.
    fn default() -> Self {
        Self::new()
    }
}
impl AtomicSpeculativeStats {
    /// Creates a statistics block with every counter at zero.
    pub fn new() -> Self {
        Self {
            draft_tokens: AtomicUsize::new(0),
            accepted_tokens: AtomicUsize::new(0),
            main_forward_passes: AtomicUsize::new(0),
            draft_forward_passes: AtomicUsize::new(0),
            total_tokens_generated: AtomicUsize::new(0),
            total_speculation_time_ns: AtomicU64::new(0),
        }
    }

    /// Lock-free accumulation of one speculation round.
    /// Each round generates `accepted_count + 1` tokens (accepted drafts plus
    /// the verifier's bonus token).
    pub fn record_round(&self, draft_count: usize, accepted_count: usize, duration: Duration) {
        self.draft_tokens.fetch_add(draft_count, Ordering::Relaxed);
        self.draft_forward_passes
            .fetch_add(draft_count, Ordering::Relaxed);
        self.accepted_tokens
            .fetch_add(accepted_count, Ordering::Relaxed);
        self.main_forward_passes.fetch_add(1, Ordering::Relaxed);
        self.total_tokens_generated
            .fetch_add(accepted_count + 1, Ordering::Relaxed);
        self.total_speculation_time_ns
            .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
    }

    /// Materializes a plain [`SpeculativeStats`] from the current counters.
    /// Loads are individually relaxed, so the snapshot is not one atomic cut
    /// across all counters.
    pub fn snapshot(&self) -> SpeculativeStats {
        let draft = self.draft_tokens.load(Ordering::Relaxed);
        let accepted = self.accepted_tokens.load(Ordering::Relaxed);
        let main_passes = self.main_forward_passes.load(Ordering::Relaxed);
        let total_tokens = self.total_tokens_generated.load(Ordering::Relaxed);
        let time_ns = self.total_speculation_time_ns.load(Ordering::Relaxed);
        let acceptance_rate = match draft {
            0 => 0.0,
            n => accepted as f32 / n as f32,
        };
        let avg_tokens_per_main_pass = match main_passes {
            0 => 0.0,
            n => total_tokens as f32 / n as f32,
        };
        SpeculativeStats {
            draft_tokens: draft,
            accepted_tokens: accepted,
            acceptance_rate,
            speedup: avg_tokens_per_main_pass,
            main_forward_passes: main_passes,
            draft_forward_passes: self.draft_forward_passes.load(Ordering::Relaxed),
            avg_tokens_per_main_pass,
            total_speculation_time_ms: time_ns as f64 / 1_000_000.0,
            total_tokens_generated: total_tokens,
        }
    }

    /// Zeroes all counters.
    pub fn reset(&self) {
        for counter in [
            &self.draft_tokens,
            &self.accepted_tokens,
            &self.main_forward_passes,
            &self.draft_forward_passes,
            &self.total_tokens_generated,
        ] {
            counter.store(0, Ordering::Relaxed);
        }
        self.total_speculation_time_ns.store(0, Ordering::Relaxed);
    }
}
/// Outcome of verifying one batch of draft tokens against the main model.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// How many leading draft tokens the main model agreed with.
    pub accepted_count: usize,
    /// The main model's own next token (bonus token appended every round).
    pub next_token: u32,
    /// Log-probabilities of the accepted tokens.
    /// NOTE(review): `verify_phase` currently always fills this with 0.0.
    pub accepted_logprobs: Vec<f32>,
    /// Log-probability of `next_token` (currently always 0.0).
    pub next_logprob: f32,
    /// True when every draft token was accepted.
    pub all_accepted: bool,
}
/// One node of the speculation tree: a candidate token plus its continuations.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Candidate token id.
    pub token: u32,
    /// Probability assigned to this token when the node was created.
    pub prob: f32,
    /// `ln(prob)`, cached at construction.
    pub logprob: f32,
    /// Speculated continuations after this token.
    pub children: Vec<TreeNode>,
    /// Distance from the root (the root has depth 0).
    pub depth: usize,
}
impl TreeNode {
    /// Builds a leaf node; `logprob` is derived as `ln(prob)`.
    pub fn new(token: u32, prob: f32, depth: usize) -> Self {
        Self {
            token,
            prob,
            logprob: prob.ln(),
            children: Vec::new(),
            depth,
        }
    }

    /// Appends a child one level deeper and returns a mutable handle to it.
    pub fn add_child(&mut self, token: u32, prob: f32) -> &mut TreeNode {
        self.children.push(TreeNode::new(token, prob, self.depth + 1));
        self.children
            .last_mut()
            .expect("children is non-empty after push")
    }

    /// Enumerates every token sequence from this node down to each leaf,
    /// with this node's token first in every sequence.
    pub fn get_paths(&self) -> Vec<Vec<u32>> {
        if self.children.is_empty() {
            return vec![vec![self.token]];
        }
        let mut paths = Vec::new();
        for child in &self.children {
            for tail in child.get_paths() {
                let mut full = Vec::with_capacity(tail.len() + 1);
                full.push(self.token);
                full.extend(tail);
                paths.push(full);
            }
        }
        paths
    }

    /// Follows the highest-probability child at each level and returns the
    /// resulting token sequence, starting with this node's token.
    pub fn best_path(&self) -> Vec<u32> {
        let mut path = vec![self.token];
        if let Some(best) = self.children.iter().max_by(|a, b| {
            a.prob
                .partial_cmp(&b.prob)
                .unwrap_or(std::cmp::Ordering::Equal)
        }) {
            path.extend(best.best_path());
        }
        path
    }
}
/// Container for tree-structured speculation rooted at a sentinel node.
#[derive(Debug)]
pub struct SpeculationTree {
    /// Sentinel root (token 0, probability 1.0); real candidates hang below it.
    pub root: TreeNode,
    /// Depth limit used when building the tree.
    pub max_depth: usize,
    /// Intended children per node.
    /// NOTE(review): not enforced by `build_speculation_tree` yet.
    pub branching_factor: usize,
    /// Number of nodes currently in the tree (including the root).
    pub node_count: usize,
}
impl SpeculationTree {
    /// Creates a tree holding only a sentinel root (token 0, probability 1.0).
    pub fn new(max_depth: usize, branching_factor: usize) -> Self {
        Self {
            root: TreeNode::new(0, 1.0, 0),
            max_depth,
            branching_factor,
            node_count: 1,
        }
    }

    /// All candidate token sequences; each starts with the sentinel root token.
    pub fn get_candidate_paths(&self) -> Vec<Vec<u32>> {
        self.root.get_paths()
    }

    /// Highest-probability token sequence with the sentinel root stripped off;
    /// empty when the tree holds nothing but the root.
    pub fn best_path(&self) -> Vec<u32> {
        let mut path = self.root.best_path();
        if path.is_empty() {
            return path;
        }
        path.remove(0);
        path
    }

    /// Drops all speculated nodes, leaving a fresh sentinel root.
    pub fn clear(&mut self) {
        self.root = TreeNode::new(0, 1.0, 0);
        self.node_count = 1;
    }
}
/// Speculative decoder pairing a large "main" model with a cheap "draft"
/// model. Each round the draft proposes several tokens; the main model
/// verifies them, keeping the accepted prefix plus one token of its own.
pub struct SpeculativeDecoder<M: LlmBackend + ?Sized, D: LlmBackend + ?Sized> {
    /// Verifier model; also supplies the tokenizer used for encode/decode.
    main_model: Arc<M>,
    /// Proposal model; expected to be much cheaper per pass than `main_model`.
    draft_model: Arc<D>,
    /// Runtime-tunable configuration, shared across threads.
    config: RwLock<SpeculativeConfig>,
    /// Lock-free counters describing decoding efficiency.
    stats: AtomicSpeculativeStats,
    /// Live adaptive lookahead window (tokens drafted per round).
    current_lookahead: AtomicUsize,
    /// Seed storage, initialized to 42.
    /// NOTE(review): never read in this file — confirm it is used elsewhere.
    rng_seed: AtomicU64,
}
impl<M: LlmBackend + ?Sized, D: LlmBackend + ?Sized> SpeculativeDecoder<M, D> {
/// Wires a main (verifier) and draft (proposer) backend together.
/// The adaptive lookahead window starts at `config.lookahead`.
pub fn new(main_model: Arc<M>, draft_model: Arc<D>, config: SpeculativeConfig) -> Self {
    // Capture the starting window before `config` moves into the lock.
    let initial_lookahead = AtomicUsize::new(config.lookahead);
    Self {
        main_model,
        draft_model,
        config: RwLock::new(config),
        stats: AtomicSpeculativeStats::new(),
        current_lookahead: initial_lookahead,
        rng_seed: AtomicU64::new(42),
    }
}
/// Returns a copy of the current configuration.
pub fn config(&self) -> SpeculativeConfig {
    self.config.read().clone()
}

/// Replaces the configuration for subsequent rounds.
pub fn set_config(&self, config: SpeculativeConfig) {
    *self.config.write() = config;
}

/// Returns a point-in-time copy of the decoding statistics.
pub fn stats(&self) -> SpeculativeStats {
    self.stats.snapshot()
}

/// Zeroes the decoding statistics.
pub fn reset_stats(&self) {
    self.stats.reset();
}

/// Tokenizer of the main model, if it exposes one.
pub fn tokenizer(&self) -> Option<&dyn Tokenizer> {
    self.main_model.tokenizer()
}
/// Encodes `text` with the main model's tokenizer.
///
/// # Errors
/// Returns [`RuvLLMError::InvalidOperation`] when the main model exposes
/// no tokenizer; otherwise propagates encoder errors.
fn tokenize(&self, text: &str) -> Result<Vec<u32>> {
    match self.main_model.tokenizer() {
        Some(tokenizer) => tokenizer.encode(text),
        None => Err(RuvLLMError::InvalidOperation(
            "No tokenizer available".to_string(),
        )),
    }
}

/// Decodes `tokens` back to text with the main model's tokenizer.
///
/// # Errors
/// Returns [`RuvLLMError::InvalidOperation`] when the main model exposes
/// no tokenizer; otherwise propagates decoder errors.
fn decode(&self, tokens: &[u32]) -> Result<String> {
    match self.main_model.tokenizer() {
        Some(tokenizer) => tokenizer.decode(tokens),
        None => Err(RuvLLMError::InvalidOperation(
            "No tokenizer available".to_string(),
        )),
    }
}
/// Heuristic: speculative decoding pays off for near-greedy sampling, where
/// the draft and main models are likely to agree.
/// NOTE(review): the thresholds (temperature < 0.5, top_k == 1) are
/// hardcoded rather than taken from `SpeculativeConfig` — confirm intended.
pub fn should_use_speculative(&self, params: &GenerateParams) -> bool {
    params.temperature < 0.5 || params.top_k == 1
}
/// Generates a completion for `prompt` with speculative decoding and
/// returns the decoded text (excluding the prompt).
///
/// # Errors
/// Fails when no tokenizer is available or a backend call fails.
pub fn generate(&self, prompt: &str, params: GenerateParams) -> Result<String> {
    let tokens = self.tokenize(prompt)?;
    let generated = self.generate_tokens(&tokens, &params)?;
    self.decode(&generated)
}
/// Core speculative loop over token ids.
///
/// Each round: (1) the draft model proposes up to `lookahead` tokens,
/// (2) the main model verifies them, (3) the accepted prefix plus the
/// verifier's own next token are appended. Stops at `params.max_tokens`,
/// on EOS, or when a stop sequence appears in the decoded output.
///
/// # Errors
/// Propagates tokenizer and backend failures.
pub fn generate_tokens(
    &self,
    prompt_tokens: &[u32],
    params: &GenerateParams,
) -> Result<Vec<u32>> {
    let config = self.config.read().clone();
    let mut context = prompt_tokens.to_vec();
    let mut output = Vec::new();
    let eos_token = self
        .main_model
        .tokenizer()
        .and_then(|t| t.special_tokens().eos_token_id);
    while output.len() < params.max_tokens {
        let start = Instant::now();
        // Adaptive mode reads the live window; otherwise use the configured value.
        let lookahead = if config.adaptive_lookahead {
            self.current_lookahead.load(Ordering::Relaxed)
        } else {
            config.lookahead
        };
        let draft_tokens = self.draft_phase(&context, lookahead, &config)?;
        if draft_tokens.is_empty() {
            // Draft produced nothing: fall back to one plain main-model step.
            let main_token = self.single_main_forward(&context, params)?;
            if Some(main_token) == eos_token {
                break;
            }
            context.push(main_token);
            output.push(main_token);
            continue;
        }
        let verification = self.verify_phase(&context, &draft_tokens, params)?;
        // Keep only the verified prefix of the draft, then the bonus token.
        let accepted = &draft_tokens[..verification.accepted_count];
        context.extend_from_slice(accepted);
        output.extend_from_slice(accepted);
        if Some(verification.next_token) == eos_token {
            break;
        }
        context.push(verification.next_token);
        output.push(verification.next_token);
        let duration = start.elapsed();
        self.stats
            .record_round(draft_tokens.len(), verification.accepted_count, duration);
        if config.adaptive_lookahead {
            self.adjust_lookahead(verification.accepted_count, draft_tokens.len(), &config);
        }
        if !params.stop_sequences.is_empty() {
            // Stop sequences are matched on decoded output text.
            let current_text = self.decode(&output)?;
            for stop_seq in &params.stop_sequences {
                if current_text.contains(stop_seq) {
                    let trimmed = current_text.split(stop_seq).next().unwrap_or("");
                    // BUG FIX: `current_text` decodes only `output` (the prompt
                    // is not part of it), so the re-tokenized trimmed text must
                    // be returned whole. The previous code additionally skipped
                    // `prompt_tokens.len()` tokens, truncating the output.
                    // NOTE(review): assumes round-trip-stable tokenization —
                    // confirm for the tokenizer in use.
                    return self.tokenize(trimmed);
                }
            }
        }
    }
    Ok(output)
}
/// Runs the draft model autoregressively to propose up to `k` tokens.
///
/// Works at the text level: the current context is decoded, the draft model
/// generates one token's worth of text, and the concatenation is
/// re-tokenized to recover the new token id. Stops early when the draft
/// model yields no new token or emits EOS.
fn draft_phase(
    &self,
    context: &[u32],
    k: usize,
    config: &SpeculativeConfig,
) -> Result<Vec<u32>> {
    let mut draft = Vec::with_capacity(k);
    let mut ctx = context.to_vec();
    // Hoisted out of the loop: the draft model's EOS id never changes.
    let draft_eos = self
        .draft_model
        .tokenizer()
        .and_then(|t| t.special_tokens().eos_token_id);
    for _ in 0..k {
        let draft_params = GenerateParams {
            max_tokens: 1,
            temperature: config.draft_temperature,
            top_p: config.draft_top_p,
            // Greedy drafting pins top_k to 1; otherwise use a broad default.
            top_k: if config.draft_temperature == 0.0 { 1 } else { 40 },
            ..Default::default()
        };
        let current_prompt = self.decode(&ctx)?;
        let generated = self.draft_model.generate(&current_prompt, draft_params)?;
        // BUG FIX: re-tokenize the *current* prompt plus its continuation.
        // The previous code concatenated the continuation onto the initial
        // (pre-loop) prompt text, so from the second iteration on the
        // re-tokenized length never exceeded `ctx.len()` and drafting always
        // stopped after a single token.
        let generated_tokens = self.tokenize(&format!("{}{}", current_prompt, generated))?;
        if generated_tokens.len() <= ctx.len() {
            break;
        }
        let new_token = generated_tokens[ctx.len()];
        draft.push(new_token);
        ctx.push(new_token);
        if Some(new_token) == draft_eos {
            break;
        }
    }
    Ok(draft)
}
/// Verifies `draft_tokens` against the main model one position at a time.
///
/// Accepts the longest prefix where the main model's continuation matches
/// the draft, then returns the main model's own token for the first
/// disagreement (or, if everything matched, one bonus token).
///
/// NOTE(review): log-probabilities are placeholders (always 0.0) because the
/// text-level backend API exposes no per-token scores.
fn verify_phase(
    &self,
    context: &[u32],
    draft_tokens: &[u32],
    params: &GenerateParams,
) -> Result<VerificationResult> {
    // (Removed an unused `self.config.read()` that held the read lock for
    // the whole method without ever being consulted.)
    let mut accepted_count = 0;
    let mut accepted_logprobs = Vec::new();
    let mut ctx = context.to_vec();
    for &draft_token in draft_tokens {
        let prompt_text = self.decode(&ctx)?;
        // Caller's sampling settings, restricted to a single token.
        let main_params = GenerateParams {
            max_tokens: 1,
            ..params.clone()
        };
        let main_generated = self.main_model.generate(&prompt_text, main_params)?;
        let main_tokens = self.tokenize(&format!("{}{}", prompt_text, main_generated))?;
        if main_tokens.len() <= ctx.len() {
            // The main model produced no new token; end the round here.
            let next_token = self.single_main_forward(&ctx, params)?;
            return Ok(VerificationResult {
                accepted_count,
                next_token,
                accepted_logprobs,
                next_logprob: 0.0,
                all_accepted: false,
            });
        }
        let main_token = main_tokens[ctx.len()];
        if main_token != draft_token {
            // First disagreement: keep the main model's token instead.
            return Ok(VerificationResult {
                accepted_count,
                next_token: main_token,
                accepted_logprobs,
                next_logprob: 0.0,
                all_accepted: false,
            });
        }
        accepted_count += 1;
        accepted_logprobs.push(0.0);
        ctx.push(draft_token);
    }
    // Every draft token matched; fetch one bonus token from the main model.
    let next_token = self.single_main_forward(&ctx, params)?;
    Ok(VerificationResult {
        accepted_count,
        next_token,
        accepted_logprobs,
        next_logprob: 0.0,
        all_accepted: true,
    })
}
/// Runs one ordinary (non-speculative) main-model step on `context` and
/// returns the next token id.
///
/// Falls back to the EOS id (or 0 when none is defined) when the backend
/// produces no new token, which callers treat as end-of-generation.
fn single_main_forward(&self, context: &[u32], params: &GenerateParams) -> Result<u32> {
    let prompt_text = self.decode(context)?;
    // Caller's sampling settings, restricted to a single token. (The old
    // code re-listed temperature/top_p/top_k before `..params.clone()`,
    // which the struct update already supplies.)
    let main_params = GenerateParams {
        max_tokens: 1,
        ..params.clone()
    };
    let generated = self.main_model.generate(&prompt_text, main_params)?;
    let tokens = self.tokenize(&format!("{}{}", prompt_text, generated))?;
    if let Some(&next) = tokens.get(context.len()) {
        Ok(next)
    } else {
        Ok(self
            .main_model
            .tokenizer()
            .and_then(|t| t.special_tokens().eos_token_id)
            .unwrap_or(0))
    }
}
/// Adapts the lookahead window to the observed acceptance rate: above 90%
/// accepted grows it (capped at `max_lookahead`), below 50% shrinks it
/// (floored at `min_lookahead`), anything between leaves it unchanged.
/// NOTE(review): the 0.9/0.5 thresholds are hardcoded rather than derived
/// from `config.acceptance_threshold` — confirm that is intended.
fn adjust_lookahead(&self, accepted: usize, total: usize, config: &SpeculativeConfig) {
    // A round with no drafts is treated as neutral (rate 0.5).
    let acceptance_rate = match total {
        0 => 0.5,
        n => accepted as f32 / n as f32,
    };
    let current = self.current_lookahead.load(Ordering::Relaxed);
    let updated = if acceptance_rate > 0.9 {
        (current + 1).min(config.max_lookahead)
    } else if acceptance_rate < 0.5 {
        current.saturating_sub(1).max(config.min_lookahead)
    } else {
        current
    };
    self.current_lookahead.store(updated, Ordering::Relaxed);
}
/// Tree-based variant of [`SpeculativeDecoder::generate`]: each round
/// builds a speculation tree, verifies its highest-probability path, and
/// keeps the accepted prefix plus the verifier's next token. Falls back to
/// linear speculative decoding when `tree_speculation` is disabled.
/// NOTE(review): unlike `generate_tokens`, this loop honors no stop
/// sequences and never adapts the lookahead — confirm the asymmetry is intended.
pub fn generate_tree(&self, prompt: &str, params: GenerateParams) -> Result<String> {
    let config = self.config.read().clone();
    if !config.tree_speculation {
        return self.generate(prompt, params);
    }
    let tokens = self.tokenize(prompt)?;
    let mut context = tokens.clone();
    let mut output = Vec::new();
    let eos_token = self
        .main_model
        .tokenizer()
        .and_then(|t| t.special_tokens().eos_token_id);
    while output.len() < params.max_tokens {
        let start = Instant::now();
        let tree = self.build_speculation_tree(&context, &config)?;
        let best_path = tree.best_path();
        if best_path.is_empty() {
            // Nothing speculated: take one ordinary main-model step.
            let main_token = self.single_main_forward(&context, &params)?;
            if Some(main_token) == eos_token {
                break;
            }
            context.push(main_token);
            output.push(main_token);
            continue;
        }
        let verification = self.verify_phase(&context, &best_path, &params)?;
        // Keep the verified prefix of the best path, then the bonus token.
        let accepted = &best_path[..verification.accepted_count];
        context.extend_from_slice(accepted);
        output.extend_from_slice(accepted);
        if Some(verification.next_token) == eos_token {
            break;
        }
        context.push(verification.next_token);
        output.push(verification.next_token);
        self.stats.record_round(
            best_path.len(),
            verification.accepted_count,
            start.elapsed(),
        );
    }
    self.decode(&output)
}
/// Builds the speculation "tree" for `context`.
/// NOTE(review): despite the name, this currently produces a single linear
/// chain from one draft pass (each token's probability set to `1/(i+1)`);
/// `tree_branching_factor` is not honored yet.
fn build_speculation_tree(
    &self,
    context: &[u32],
    config: &SpeculativeConfig,
) -> Result<SpeculationTree> {
    let mut tree = SpeculationTree::new(config.max_tree_depth, config.tree_branching_factor);
    let draft_tokens = self.draft_phase(context, config.max_tree_depth, config)?;
    let mut current = &mut tree.root;
    for (i, &token) in draft_tokens.iter().enumerate() {
        // Each drafted token becomes the sole child of the previous node.
        current = current.add_child(token, 1.0 / (i + 1) as f32);
        tree.node_count += 1;
    }
    Ok(tree)
}
/// Returns a lazy iterator that yields tokens as speculation rounds finish.
/// Tokens produced by a round are buffered and emitted one at a time.
///
/// # Errors
/// Fails immediately when the prompt cannot be tokenized.
pub fn generate_stream<'a>(
    &'a self,
    prompt: &str,
    params: GenerateParams,
) -> Result<impl Iterator<Item = Result<GeneratedToken>> + 'a> {
    // Tokenize straight into the iterator's context (the previous version
    // kept an extra, never-used clone of the token vector).
    let context = self.tokenize(prompt)?;
    let config = self.config.read().clone();
    Ok(SpeculativeStreamIterator {
        decoder: self,
        context,
        params,
        config,
        output_count: 0,
        pending_tokens: Vec::new(),
        finished: false,
    })
}
}
/// Streaming adapter that yields tokens from speculative rounds one at a time.
struct SpeculativeStreamIterator<'a, M: LlmBackend + ?Sized, D: LlmBackend + ?Sized> {
    /// Decoder that performs the draft/verify work.
    decoder: &'a SpeculativeDecoder<M, D>,
    /// Prompt plus everything generated so far.
    context: Vec<u32>,
    /// Sampling parameters for the whole stream.
    params: GenerateParams,
    /// Configuration snapshot taken when the stream was created.
    config: SpeculativeConfig,
    /// Tokens yielded so far (bounded by `params.max_tokens`).
    output_count: usize,
    /// Tokens produced by the last round but not yet yielded.
    pending_tokens: Vec<u32>,
    /// Set on EOS or error; the iterator then returns `None`.
    finished: bool,
}
impl<'a, M: LlmBackend + ?Sized, D: LlmBackend + ?Sized> Iterator
    for SpeculativeStreamIterator<'a, M, D>
{
    type Item = Result<GeneratedToken>;

    /// Yields the next generated token, running a fresh speculation round
    /// whenever the pending buffer is empty.
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished || self.output_count >= self.params.max_tokens {
            return None;
        }
        // Drain tokens buffered from the previous round first.
        // NOTE(review): buffered tokens are never checked against EOS or
        // stop sequences — confirm that is acceptable for streaming.
        if !self.pending_tokens.is_empty() {
            let token = self.pending_tokens.remove(0);
            self.output_count += 1;
            let text = self.decoder.decode(&[token]).unwrap_or_default();
            return Some(Ok(GeneratedToken {
                id: token,
                text,
                logprob: None,
                is_special: false,
            }));
        }
        let lookahead = self.config.lookahead;
        let draft_result = self
            .decoder
            .draft_phase(&self.context, lookahead, &self.config);
        match draft_result {
            Ok(draft_tokens) if !draft_tokens.is_empty() => {
                match self
                    .decoder
                    .verify_phase(&self.context, &draft_tokens, &self.params)
                {
                    Ok(verification) => {
                        let accepted = &draft_tokens[..verification.accepted_count];
                        self.pending_tokens.extend_from_slice(accepted);
                        self.pending_tokens.push(verification.next_token);
                        self.context.extend_from_slice(accepted);
                        self.context.push(verification.next_token);
                        // Depth-1 recursion: `next_token` was just buffered, so
                        // this call takes the drain branch above.
                        self.next()
                    }
                    Err(e) => {
                        self.finished = true;
                        Some(Err(e))
                    }
                }
            }
            Ok(_) => {
                // Empty draft: fall back to a single main-model step.
                match self
                    .decoder
                    .single_main_forward(&self.context, &self.params)
                {
                    Ok(token) => {
                        self.context.push(token);
                        self.output_count += 1;
                        let eos = self
                            .decoder
                            .main_model
                            .tokenizer()
                            .and_then(|t| t.special_tokens().eos_token_id);
                        if Some(token) == eos {
                            self.finished = true;
                        }
                        let text = self.decoder.decode(&[token]).unwrap_or_default();
                        Some(Ok(GeneratedToken {
                            id: token,
                            text,
                            logprob: None,
                            is_special: Some(token) == eos,
                        }))
                    }
                    Err(e) => {
                        self.finished = true;
                        Some(Err(e))
                    }
                }
            }
            Err(e) => {
                self.finished = true;
                Some(Err(e))
            }
        }
    }
}
/// Numerically stable softmax over `logits`.
///
/// On aarch64 with NEON this dispatches to the SIMD implementation;
/// everywhere else a scalar max-shifted exponential is used. Returns an
/// empty vector for empty input.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        softmax_neon_optimized(logits)
    }
    #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
    {
        // Shift by the max so exp never overflows.
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.into_iter().map(|e| e / total).collect()
    }
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn softmax_neon_optimized(logits: &[f32]) -> Vec<f32> {
    // NEON softmax: vectorized max reduction, vectorized exp via a 6th-order
    // Taylor polynomial, then a vectorized normalization pass.
    // NOTE(review): the Taylor polynomial only approximates exp(x) well near
    // zero. Inputs are clamped to [-20, 20], but for strongly negative x the
    // even-power terms dominate and the polynomial yields values far above
    // the true near-zero exp, so low-probability entries may be significantly
    // overestimated relative to the scalar path — confirm the accuracy is
    // acceptable for sampling.
    use std::arch::aarch64::*;
    const UNROLL_8X: usize = 8;
    if logits.is_empty() {
        return vec![];
    }
    let mut result = vec![0.0f32; logits.len()];
    // SAFETY: every vld1q/vst1q below reads or writes 4 f32s starting at
    // `base` or `base + 4` with `base + 7 < chunks * UNROLL_8X <= len`,
    // so all accesses stay inside `logits` / `result`.
    unsafe {
        // Pass 1: maximum logit (for numerical stability), 8 lanes at a time.
        let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
        let chunks = logits.len() / UNROLL_8X;
        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let v0 = vld1q_f32(logits.as_ptr().add(base));
            let v1 = vld1q_f32(logits.as_ptr().add(base + 4));
            max_vec = vmaxq_f32(max_vec, vmaxq_f32(v0, v1));
        }
        let mut max_logit = vmaxvq_f32(max_vec);
        // Scalar tail for the final len % 8 elements.
        for i in (chunks * UNROLL_8X)..logits.len() {
            max_logit = max_logit.max(logits[i]);
        }
        let max_vec = vdupq_n_f32(max_logit);
        // Broadcast Taylor coefficients 1/n! for the polynomial exp.
        let one = vdupq_n_f32(1.0);
        let half = vdupq_n_f32(0.5);
        let sixth = vdupq_n_f32(1.0 / 6.0);
        let twenty_fourth = vdupq_n_f32(1.0 / 24.0);
        let one_twenty = vdupq_n_f32(1.0 / 120.0);
        let seven_twenty = vdupq_n_f32(1.0 / 720.0);
        let mut sum0 = vdupq_n_f32(0.0);
        let mut sum1 = vdupq_n_f32(0.0);
        // Approximates exp(x) per lane with a clamped 6th-order Taylor
        // series; negative polynomial values are floored at zero.
        #[inline(always)]
        unsafe fn fast_exp_vec(
            x: float32x4_t,
            one: float32x4_t,
            half: float32x4_t,
            sixth: float32x4_t,
            twenty_fourth: float32x4_t,
            one_twenty: float32x4_t,
            seven_twenty: float32x4_t,
        ) -> float32x4_t {
            // Clamp to [-20, 20] to bound the polynomial's domain.
            let x = vmaxq_f32(vdupq_n_f32(-20.0), vminq_f32(x, vdupq_n_f32(20.0)));
            let x2 = vmulq_f32(x, x);
            let x3 = vmulq_f32(x2, x);
            let x4 = vmulq_f32(x2, x2);
            let x5 = vmulq_f32(x4, x);
            let x6 = vmulq_f32(x3, x3);
            // 1 + x + x^2/2 + x^3/6 + x^4/24 + x^5/120 + x^6/720
            let result = vaddq_f32(one, x);
            let result = vfmaq_f32(result, x2, half);
            let result = vfmaq_f32(result, x3, sixth);
            let result = vfmaq_f32(result, x4, twenty_fourth);
            let result = vfmaq_f32(result, x5, one_twenty);
            let result = vfmaq_f32(result, x6, seven_twenty);
            vmaxq_f32(result, vdupq_n_f32(0.0))
        }
        // Pass 2: exp(x - max) for each chunk, accumulating the running sum.
        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let v0 = vld1q_f32(logits.as_ptr().add(base));
            let v1 = vld1q_f32(logits.as_ptr().add(base + 4));
            let d0 = vsubq_f32(v0, max_vec);
            let d1 = vsubq_f32(v1, max_vec);
            let e0 = fast_exp_vec(
                d0,
                one,
                half,
                sixth,
                twenty_fourth,
                one_twenty,
                seven_twenty,
            );
            let e1 = fast_exp_vec(
                d1,
                one,
                half,
                sixth,
                twenty_fourth,
                one_twenty,
                seven_twenty,
            );
            vst1q_f32(result.as_mut_ptr().add(base), e0);
            vst1q_f32(result.as_mut_ptr().add(base + 4), e1);
            sum0 = vaddq_f32(sum0, e0);
            sum1 = vaddq_f32(sum1, e1);
        }
        let mut exp_sum = vaddvq_f32(vaddq_f32(sum0, sum1));
        // Scalar tail uses the exact libm exp.
        for i in (chunks * UNROLL_8X)..logits.len() {
            let e = (logits[i] - max_logit).exp();
            result[i] = e;
            exp_sum += e;
        }
        // Pass 3: normalize by the total.
        let inv_sum = vdupq_n_f32(1.0 / exp_sum);
        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let e0 = vld1q_f32(result.as_ptr().add(base));
            let e1 = vld1q_f32(result.as_ptr().add(base + 4));
            let p0 = vmulq_f32(e0, inv_sum);
            let p1 = vmulq_f32(e1, inv_sum);
            vst1q_f32(result.as_mut_ptr().add(base), p0);
            vst1q_f32(result.as_mut_ptr().add(base + 4), p1);
        }
        let inv_sum_scalar = 1.0 / exp_sum;
        for i in (chunks * UNROLL_8X)..logits.len() {
            result[i] *= inv_sum_scalar;
        }
    }
    result
}
/// Numerically stable log-softmax over `logits`.
/// Dispatches to the NEON path on aarch64 with NEON, scalar elsewhere.
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        log_softmax_neon_optimized(logits)
    }
    #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
    {
        // log_softmax(x) = x - (max + ln(sum(exp(x - max))))
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = logits.iter().map(|&x| (x - max).exp()).sum();
        let log_sum_exp = sum_exp.ln() + max;
        logits.iter().map(|&x| x - log_sum_exp).collect()
    }
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn log_softmax_neon_optimized(logits: &[f32]) -> Vec<f32> {
    // NEON log-softmax: vectorized max reduction, scalar (exact) exp-sum,
    // then a vectorized subtraction of the log-sum-exp.
    use std::arch::aarch64::*;
    const UNROLL_8X: usize = 8;
    if logits.is_empty() {
        return vec![];
    }
    let mut result = vec![0.0f32; logits.len()];
    // SAFETY: all vector loads/stores cover indices < chunks * UNROLL_8X,
    // which is bounded by `logits.len()` / `result.len()`.
    unsafe {
        // Pass 1: maximum logit for numerical stability.
        let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
        let chunks = logits.len() / UNROLL_8X;
        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let v0 = vld1q_f32(logits.as_ptr().add(base));
            let v1 = vld1q_f32(logits.as_ptr().add(base + 4));
            max_vec = vmaxq_f32(max_vec, vmaxq_f32(v0, v1));
        }
        let mut max_logit = vmaxvq_f32(max_vec);
        // Scalar tail for the final len % 8 elements.
        for i in (chunks * UNROLL_8X)..logits.len() {
            max_logit = max_logit.max(logits[i]);
        }
        // Pass 2: exact scalar sum of exp(x - max) — precision over speed here.
        let mut exp_sum = 0.0f32;
        for i in 0..logits.len() {
            exp_sum += (logits[i] - max_logit).exp();
        }
        let log_sum_exp = exp_sum.ln() + max_logit;
        // Pass 3: subtract log-sum-exp from every logit, 8 lanes at a time.
        let log_sum_vec = vdupq_n_f32(log_sum_exp);
        for c in 0..chunks {
            let base = c * UNROLL_8X;
            let v0 = vld1q_f32(logits.as_ptr().add(base));
            let v1 = vld1q_f32(logits.as_ptr().add(base + 4));
            let r0 = vsubq_f32(v0, log_sum_vec);
            let r1 = vsubq_f32(v1, log_sum_vec);
            vst1q_f32(result.as_mut_ptr().add(base), r0);
            vst1q_f32(result.as_mut_ptr().add(base + 4), r1);
        }
        for i in (chunks * UNROLL_8X)..logits.len() {
            result[i] = logits[i] - log_sum_exp;
        }
    }
    result
}
/// Draws an index from the categorical distribution `probs` via inverse
/// transform sampling. Falls back to the last index when floating-point
/// rounding keeps the cumulative sum below the drawn threshold.
/// NOTE(review): panics on an empty slice (`probs.len() - 1` underflows) —
/// callers must supply at least one probability.
pub fn sample_from_probs(probs: &[f32], rng: &mut impl Rng) -> usize {
    let threshold: f32 = rng.gen();
    let mut cumulative = 0.0;
    for (index, &prob) in probs.iter().enumerate() {
        cumulative += prob;
        if cumulative > threshold {
            return index;
        }
    }
    probs.len() - 1
}
/// Keeps only the `k` largest logits, masking the rest to negative infinity.
/// Ties at the k-th value are all kept (same as before). `k == 0` or
/// `k >= len` leaves the slice untouched.
///
/// Uses `select_nth_unstable_by` (O(n) average) instead of a full sort, and
/// `f32::total_cmp` so NaN logits cannot panic the comparison (the old
/// `partial_cmp().unwrap()` panicked on NaN).
pub fn top_k_filter(logits: &mut [f32], k: usize) {
    if k == 0 || k >= logits.len() {
        return;
    }
    let mut values = logits.to_vec();
    // Partition so that index k-1 holds the k-th largest value.
    values.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a));
    let threshold = values[k - 1];
    for logit in logits.iter_mut() {
        if *logit < threshold {
            *logit = f32::NEG_INFINITY;
        }
    }
}
/// Nucleus (top-p) filtering: keeps the smallest set of logits whose softmax
/// probabilities sum to more than `p`; everything else is masked to negative
/// infinity. `p >= 1.0` leaves the slice untouched.
///
/// Uses `f32::total_cmp` so NaN probabilities cannot panic the sort (the old
/// `partial_cmp().unwrap()` panicked on NaN input).
pub fn top_p_filter(logits: &mut [f32], p: f32) {
    if p >= 1.0 {
        return;
    }
    let probs = softmax(logits);
    let mut indexed: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
    // Highest probability first; total_cmp is NaN-safe.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    let mut cumsum = 0.0;
    let mut cutoff_idx = indexed.len();
    for (i, (_, prob)) in indexed.iter().enumerate() {
        cumsum += prob;
        if cumsum > p {
            cutoff_idx = i + 1;
            break;
        }
    }
    // Mask everything outside the nucleus.
    let included: std::collections::HashSet<usize> =
        indexed[..cutoff_idx].iter().map(|(i, _)| *i).collect();
    for (i, logit) in logits.iter_mut().enumerate() {
        if !included.contains(&i) {
            *logit = f32::NEG_INFINITY;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_speculative_config_default() {
        let cfg = SpeculativeConfig::default();
        assert_eq!(cfg.lookahead, 4);
        assert!((cfg.acceptance_threshold - 0.5).abs() < 0.01);
        assert!(!cfg.tree_speculation);
    }

    #[test]
    fn test_speculative_stats() {
        let mut stats = SpeculativeStats::new();
        assert_eq!(stats.draft_tokens, 0);
        assert_eq!(stats.accepted_tokens, 0);
        stats.record_round(4, 3, 10.0);
        assert_eq!(stats.draft_tokens, 4);
        assert_eq!(stats.accepted_tokens, 3);
        assert!((stats.acceptance_rate - 0.75).abs() < 0.01);
        // 3 accepted draft tokens + 1 verifier bonus token.
        assert_eq!(stats.total_tokens_generated, 4);
    }

    #[test]
    fn test_atomic_stats() {
        let stats = AtomicSpeculativeStats::new();
        stats.record_round(4, 3, Duration::from_millis(10));
        let snap = stats.snapshot();
        assert_eq!(snap.draft_tokens, 4);
        assert_eq!(snap.accepted_tokens, 3);
        assert!((snap.acceptance_rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_tree_node() {
        let mut root = TreeNode::new(0, 1.0, 0);
        root.add_child(1, 0.5);
        root.add_child(2, 0.3);
        assert_eq!(root.children.len(), 2);
        assert_eq!(root.children[0].token, 1);
        assert_eq!(root.children[1].token, 2);
    }

    #[test]
    fn test_speculation_tree() {
        let mut tree = SpeculationTree::new(3, 2);
        assert_eq!(tree.node_count, 1);
        tree.root.add_child(1, 0.8);
        tree.node_count += 1;
        assert_eq!(tree.node_count, 2);
    }

    #[test]
    fn test_softmax() {
        let probs = softmax(&[1.0, 2.0, 3.0]);
        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 0.001);
        assert!(probs[2] > probs[1]);
        assert!(probs[1] > probs[0]);
    }

    #[test]
    fn test_top_k_filter() {
        let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        top_k_filter(&mut logits, 2);
        assert_eq!(logits.iter().filter(|x| x.is_finite()).count(), 2);
    }

    #[test]
    fn test_top_p_filter() {
        let mut logits = vec![10.0, 5.0, 3.0, 2.0, 1.0];
        top_p_filter(&mut logits, 0.9);
        assert!(logits.iter().filter(|x| x.is_finite()).count() >= 1);
    }

    #[test]
    fn test_verification_result() {
        let result = VerificationResult {
            accepted_count: 3,
            next_token: 42,
            accepted_logprobs: vec![-0.1, -0.2, -0.3],
            next_logprob: -0.5,
            all_accepted: false,
        };
        assert_eq!(result.accepted_count, 3);
        assert_eq!(result.next_token, 42);
        assert!(!result.all_accepted);
    }
}