use crate::error::{Error, Result};
use crate::ops::traits::inference::speculative::VerificationResult;
use numr::dtype::DType;
use numr::ops::{AdvancedRandomOps, BinaryOps, ScalarOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Inverse-CDF sampling over a slice of (possibly unnormalized) weights.
///
/// Accumulates a running sum over `probs` and returns the first index whose
/// cumulative weight is strictly greater than `r`. Callers are expected to
/// pass `r` already scaled to the total mass of `probs` (i.e. a uniform draw
/// in `[0, 1)` multiplied by the sum of the weights).
///
/// When no prefix sum exceeds `r` (floating-point rounding in the running
/// sum, `r` equal to the total mass, or a NaN `r`), the function falls back
/// to the last index with strictly positive weight, so a zero-mass entry is
/// never chosen while a positive one exists. If every weight is non-positive
/// the last index is returned, and an empty slice yields `0` instead of
/// underflowing `len() - 1`.
fn sample_with_uniform(probs: &[f32], r: f32) -> u32 {
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if r < cumsum {
            return i as u32;
        }
    }

    // Rounding fallback: prefer an entry that actually carries mass.
    let fallback = probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or_else(|| probs.len().saturating_sub(1));
    fallback as u32
}
/// Runs the accept/reject pass of speculative decoding on host-side copies of
/// the inputs and returns one [`VerificationResult`] per batch row.
///
/// # Arguments
///
/// * `client` - runtime client used only to generate the uniform random
///   numbers via `philox_uniform`.
/// * `draft_probs` - F32 tensor of shape `[batch, K, vocab]`.
/// * `target_probs` - F32 tensor of shape `[batch, K + 1, vocab]`.
/// * `draft_tokens` - I32 tensor of shape `[batch, K]` holding the token id
///   proposed at each draft position.
/// * `seed` - seed forwarded to `philox_uniform`.
///
/// # Algorithm
///
/// For each batch row, draft positions are visited in order. The token at a
/// position is accepted when a uniform draw is below
/// `min(1, target_prob / draft_prob)`; a draft probability that is not
/// strictly positive is treated as acceptance probability `1.0`. The first
/// rejection stops the scan for that row.
///
/// One extra "bonus" token is then sampled per row:
/// * after a rejection at position `p`, from the residual
///   `max(target[p] - draft[p], 0)`, or from `target[p]` itself when the
///   residual mass is at most `1e-8`;
/// * when every draft token was accepted, from `target[K]`.
///
/// Each row consumes `K + 1` uniforms: one per draft position and the last
/// one for the bonus token.
///
/// # Errors
///
/// Returns [`Error::InvalidArgument`] when a rank, shape or dtype does not
/// match the layout above, when the vocabulary dimension is zero, or when a
/// draft token that is actually inspected is negative or `>= vocab`.
/// Positions after the first rejection in a row are never read and therefore
/// never validated. Failures from `philox_uniform` are wrapped in
/// [`Error::Numr`].
pub fn verify_speculative_tokens_impl<R, C>(
    client: &C,
    draft_probs: &Tensor<R>,
    target_probs: &Tensor<R>,
    draft_tokens: &Tensor<R>,
    seed: u64,
) -> Result<Vec<VerificationResult>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + AdvancedRandomOps<R>,
{
    let dp_shape = draft_probs.shape();
    let tp_shape = target_probs.shape();
    let dt_shape = draft_tokens.shape();

    // --- rank checks -------------------------------------------------------
    if dp_shape.len() != 3 {
        return Err(Error::InvalidArgument {
            arg: "draft_probs",
            reason: format!("expected 3D [batch, K, vocab], got {}D", dp_shape.len()),
        });
    }
    if tp_shape.len() != 3 {
        return Err(Error::InvalidArgument {
            arg: "target_probs",
            reason: format!("expected 3D [batch, K+1, vocab], got {}D", tp_shape.len()),
        });
    }
    if dt_shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "draft_tokens",
            reason: format!("expected 2D [batch, K], got {}D", dt_shape.len()),
        });
    }

    // --- shape agreement ---------------------------------------------------
    let batch_size = dp_shape[0];
    let num_draft = dp_shape[1];
    let vocab_size = dp_shape[2];

    if tp_shape[0] != batch_size || tp_shape[1] != num_draft + 1 || tp_shape[2] != vocab_size {
        return Err(Error::InvalidArgument {
            arg: "target_probs",
            reason: format!(
                "expected [{}, {}, {}], got {:?}",
                batch_size,
                num_draft + 1,
                vocab_size,
                tp_shape
            ),
        });
    }
    if dt_shape[0] != batch_size || dt_shape[1] != num_draft {
        return Err(Error::InvalidArgument {
            arg: "draft_tokens",
            reason: format!(
                "expected [{}, {}], got {:?}",
                batch_size, num_draft, dt_shape
            ),
        });
    }
    // The bonus token is always sampled from a vocab-sized row, so an empty
    // vocabulary leaves nothing to sample from.
    if vocab_size == 0 {
        return Err(Error::InvalidArgument {
            arg: "draft_probs",
            reason: "vocab dimension must be non-zero".to_string(),
        });
    }

    // --- dtype checks ------------------------------------------------------
    // Every buffer below is read back with a fixed element type, so the
    // tensor dtypes have to match exactly.
    if draft_probs.dtype() != DType::F32 {
        return Err(Error::InvalidArgument {
            arg: "draft_probs",
            reason: format!("expected F32, got {:?}", draft_probs.dtype()),
        });
    }
    if target_probs.dtype() != DType::F32 {
        return Err(Error::InvalidArgument {
            arg: "target_probs",
            reason: format!("expected F32, got {:?}", target_probs.dtype()),
        });
    }
    if draft_tokens.dtype() != DType::I32 {
        return Err(Error::InvalidArgument {
            arg: "draft_tokens",
            reason: format!("expected I32, got {:?}", draft_tokens.dtype()),
        });
    }

    // One uniform per draft position plus one for the bonus token.
    // NOTE(review): the third argument (0) is presumably the Philox
    // counter/offset - confirm against the numr API.
    let randoms_tensor = client
        .philox_uniform(&[batch_size, num_draft + 1], seed, 0, DType::F32)
        .map_err(Error::Numr)?;

    let dp_data = draft_probs.contiguous().to_vec::<f32>();
    let tp_data = target_probs.contiguous().to_vec::<f32>();
    let dt_data = draft_tokens.contiguous().to_vec::<i32>();
    let rand_data = randoms_tensor.to_vec::<f32>();

    let mut results = Vec::with_capacity(batch_size);
    for b in 0..batch_size {
        // Row-major base offsets of this batch row in each flat buffer.
        let rand_base = b * (num_draft + 1);
        let dp_base = b * num_draft * vocab_size;
        let tp_base = b * (num_draft + 1) * vocab_size;

        let mut accepted_tokens = Vec::with_capacity(num_draft);
        let mut rejected_at: Option<usize> = None;

        for pos in 0..num_draft {
            let raw_token = dt_data[b * num_draft + pos];
            // An out-of-range id would index into a different row (or past
            // the end of the buffer), so reject it explicitly.
            if raw_token < 0 || raw_token as usize >= vocab_size {
                return Err(Error::InvalidArgument {
                    arg: "draft_tokens",
                    reason: format!(
                        "token id {} at [{}, {}] is outside the vocabulary range 0..{}",
                        raw_token, b, pos, vocab_size
                    ),
                });
            }
            let draft_token = raw_token as usize;

            let draft_prob = dp_data[dp_base + pos * vocab_size + draft_token];
            let target_prob = tp_data[tp_base + pos * vocab_size + draft_token];

            // min(1, p_target / p_draft); a non-positive draft probability is
            // treated as "always accept" rather than dividing by zero.
            let accept_prob = if draft_prob > 0.0 {
                (target_prob / draft_prob).min(1.0)
            } else {
                1.0
            };

            if rand_data[rand_base + pos] < accept_prob {
                accepted_tokens.push(draft_token as u32);
            } else {
                rejected_at = Some(pos);
                break;
            }
        }

        let bonus_r = rand_data[rand_base + num_draft];

        // Samples from the target distribution at `pos`, scaling the uniform
        // draw by the row's total mass so the row need not be normalized.
        let sample_target_row = |pos: usize| -> u32 {
            let start = tp_base + pos * vocab_size;
            let row = &tp_data[start..start + vocab_size];
            let total: f32 = row.iter().sum();
            sample_with_uniform(row, bonus_r * total)
        };

        let bonus_token = match rejected_at {
            Some(pos) => {
                let dp_start = dp_base + pos * vocab_size;
                let tp_start = tp_base + pos * vocab_size;
                let draft_row = &dp_data[dp_start..dp_start + vocab_size];
                let target_row = &tp_data[tp_start..tp_start + vocab_size];

                // Residual distribution max(target - draft, 0), unnormalized.
                let residual: Vec<f32> = target_row
                    .iter()
                    .zip(draft_row)
                    .map(|(&t, &d)| (t - d).max(0.0))
                    .collect();
                let sum: f32 = residual.iter().sum();

                if sum > 1e-8 {
                    sample_with_uniform(&residual, bonus_r * sum)
                } else {
                    // Residual has (numerically) no mass: fall back to the
                    // target distribution at the rejected position.
                    sample_target_row(pos)
                }
            }
            // Every draft token was accepted: sample the extra token from
            // the final (K-th) target row.
            None => sample_target_row(num_draft),
        };

        let num_accepted = accepted_tokens.len();
        results.push(VerificationResult {
            num_accepted,
            accepted_tokens,
            bonus_token,
            // The result type encodes "no rejection" as -1.
            first_rejection_pos: rejected_at.map_or(-1, |pos| pos as i32),
        });
    }

    Ok(results)
}
pub fn compute_acceptance_probs_impl<R, C>(
client: &C,
draft_probs: &Tensor<R>,
target_probs: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + BinaryOps<R> + ScalarOps<R>,
{
let dp_shape = draft_probs.shape();
let tp_shape = target_probs.shape();
if dp_shape != tp_shape {
return Err(Error::InvalidArgument {
arg: "target_probs",
reason: format!(
"shape mismatch: draft {:?} vs target {:?}",
dp_shape, tp_shape
),
});
}
let eps = Tensor::<R>::full_scalar(dp_shape, DType::F32, 1e-10, draft_probs.device());
let ones = Tensor::<R>::full_scalar(dp_shape, DType::F32, 1.0, draft_probs.device());
let zeros = Tensor::<R>::full_scalar(dp_shape, DType::F32, 0.0, draft_probs.device());
let safe_draft = client.maximum(draft_probs, &eps).map_err(Error::Numr)?;
let ratio = client.div(target_probs, &safe_draft).map_err(Error::Numr)?;
let acceptance = client.minimum(&ratio, &ones).map_err(Error::Numr)?;
let diff = client.sub(target_probs, draft_probs).map_err(Error::Numr)?;
let residual = client.maximum(&diff, &zeros).map_err(Error::Numr)?;
Ok((acceptance, residual))
}
pub fn compute_expected_tokens_impl<R, C>(
_client: &C,
acceptance_rates: &Tensor<R>,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>,
{
let shape = acceptance_rates.shape();
if shape.len() != 2 {
return Err(Error::InvalidArgument {
arg: "acceptance_rates",
reason: format!("expected 2D [batch, K], got {}D", shape.len()),
});
}
let batch_size = shape[0];
let k = shape[1];
let rates = acceptance_rates.contiguous().to_vec::<f32>();
let mut expected = Vec::with_capacity(batch_size);
for b in 0..batch_size {
let mut cumulative_prob = 1.0f32;
let mut exp_tokens = 0.0f32;
for i in 0..k {
cumulative_prob *= rates[b * k + i];
exp_tokens += cumulative_prob;
}
expected.push(exp_tokens + 1.0);
}
Ok(Tensor::<R>::from_slice(
&expected,
&[batch_size],
acceptance_rates.device(),
))
}