#![cfg(any(test, feature = "test-util"))]
use crate::Result;
use std::sync::Arc;
/// Shared, thread-safe callback that produces one logit vector.
///
/// Within this file it is always invoked as `f(prefix, prefix.len())`: the
/// first argument is the token prefix to condition on and the second is the
/// length of that prefix (the position of the token being predicted).
pub type LogitFn = Arc<dyn Fn(&[u32], usize) -> Vec<f32> + Send + Sync>;
/// In-memory decoder stand-in whose logits come from a user-supplied closure.
///
/// Cloning copies the token history and shares the same `logits_fn` (the
/// closure lives behind an `Arc`).
#[derive(Clone)]
pub struct MockDecoder {
// Number of logits the callback is expected to return per call.
vocab_size: usize,
// Tokens recorded so far via `observe` (and trimmed by `rollback_to`/`reset`).
history: Vec<u32>,
// Callback evaluated on a token prefix to obtain the next-token logits.
logits_fn: LogitFn,
}
// Hand-written `Debug`: the logit closure cannot be printed, so only the
// vocabulary size and the number of recorded tokens are shown.
impl std::fmt::Debug for MockDecoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            vocab_size,
            history,
            ..
        } = self;
        let recorded = history.len();

        let mut repr = f.debug_struct("MockDecoder");
        repr.field("vocab_size", vocab_size);
        repr.field("history_len", &recorded);
        repr.finish()
    }
}
impl MockDecoder {
    /// Creates a decoder with an empty history.
    ///
    /// `vocab_size` is the number of logits `logits_fn` is expected to return
    /// on every call; the expectation is checked with a debug assertion each
    /// time the callback is evaluated.
    pub fn new(vocab_size: usize, logits_fn: LogitFn) -> Self {
        Self {
            vocab_size,
            history: Vec::new(),
            logits_fn,
        }
    }

    /// Returns the vocabulary size given at construction.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Forgets every recorded token.
    pub fn reset(&mut self) {
        self.history.clear();
    }

    /// Returns the tokens recorded so far, oldest first.
    pub fn history(&self) -> &[u32] {
        &self.history
    }

    /// Appends `tokens` to the recorded history.
    pub fn observe(&mut self, tokens: &[u32]) {
        self.history.extend_from_slice(tokens);
    }

    /// Shortens the history to its first `len` tokens.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `len` exceeds the
    /// current history length. Without debug assertions such a call leaves
    /// the history untouched, because `Vec::truncate` ignores larger lengths.
    pub fn rollback_to(&mut self, len: usize) {
        debug_assert!(len <= self.history.len(), "cannot roll forward");
        self.history.truncate(len);
    }

    /// Evaluates the callback on the current history and returns its logits.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if the callback does
    /// not return exactly `vocab_size` logits.
    pub fn next_logits(&self) -> Vec<f32> {
        self.checked_logits(&self.history)
    }

    /// Evaluates the callback once per draft prefix without mutating `self`.
    ///
    /// Row `0` is computed from the current history; row `i` (for `i >= 1`)
    /// from the history followed by `drafts[..i]`. The result therefore has
    /// `drafts.len() + 1` rows.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if any row does not
    /// contain exactly `vocab_size` logits (same check as `next_logits`).
    pub fn batched_logits(&self, drafts: &[u32]) -> Vec<Vec<f32>> {
        // Reserve room for the drafts up front so pushing never reallocates.
        let mut prefix: Vec<u32> = Vec::with_capacity(self.history.len() + drafts.len());
        prefix.extend_from_slice(&self.history);

        let mut out = Vec::with_capacity(drafts.len() + 1);
        out.push(self.checked_logits(&prefix));
        for &draft in drafts {
            prefix.push(draft);
            out.push(self.checked_logits(&prefix));
        }
        out
    }

    /// Runs the callback on `prefix` and debug-asserts the logit count.
    fn checked_logits(&self, prefix: &[u32]) -> Vec<f32> {
        let raw = (self.logits_fn)(prefix, prefix.len());
        debug_assert_eq!(
            raw.len(),
            self.vocab_size,
            "logit_fn returned {} logits but vocab_size is {}",
            raw.len(),
            self.vocab_size
        );
        raw
    }
}
// Opts `MockDecoder` into the parent module's `TreeDecoder` trait. The body is
// empty, so no items are overridden here; anything the trait offers comes from
// its own provided definitions (trait is declared outside this file — confirm
// there what those are).
impl super::TreeDecoder for MockDecoder {}
impl super::Decoder for MockDecoder {
fn vocab_size(&self) -> usize {
self.vocab_size
}
fn history(&self) -> &[u32] {
&self.history
}
fn reset(&mut self) {
self.history.clear();
}
fn observe(&mut self, ids: &[u32]) -> Result<()> {
self.history.extend_from_slice(ids);
Ok(())
}
fn next_logits(&mut self) -> Result<Vec<f32>> {
Ok(MockDecoder::next_logits(self))
}
fn batched_logits(&mut self, drafts: &[u32]) -> Result<Vec<Vec<f32>>> {
let out = MockDecoder::batched_logits(self, drafts);
self.history.extend_from_slice(drafts);
Ok(out)
}
fn rollback_to(&mut self, len: usize) -> Result<()> {
if len > self.history.len() {
return Err(crate::Error::CacheRollback(format!(
"rollback target {len} > history length {}",
self.history.len()
)));
}
self.history.truncate(len);
Ok(())
}
}
/// Builds a [`MockDecoder`] that returns the same logits on every call.
///
/// The logits are the natural logarithms of `probs`, with each probability
/// first raised to at least `1e-30` so that a zero entry yields a large
/// negative finite logit instead of negative infinity. The vocabulary size is
/// `probs.len()`; the history and position passed to the callback are ignored.
pub fn fixed_distribution(probs: Vec<f32>) -> MockDecoder {
    // Lower bound applied before taking the logarithm.
    const FLOOR: f32 = 1e-30;

    let vocab = probs.len();
    let table: Vec<f32> = probs.into_iter().map(|p| p.max(FLOOR).ln()).collect();

    // The closure owns the precomputed table and hands out a fresh copy each time.
    let constant: LogitFn = Arc::new(move |_history, _pos| table.to_vec());
    MockDecoder::new(vocab, constant)
}