use std::collections::HashMap;
/// Deterministic xorshift64 pseudo-random number generator.
///
/// Not cryptographically secure; intended for cheap, reproducible sampling.
#[derive(Debug, Clone)]
pub struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed would trap xorshift at zero forever, so it is replaced
    /// with a fixed non-zero constant.
    pub fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0x853c_49e6_748f_ea9b,
            s => s,
        };
        Self { state }
    }

    /// Advances the generator and returns the next 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        let mut v = self.state;
        v ^= v << 13;
        v ^= v >> 7;
        v ^= v << 17;
        self.state = v;
        v
    }

    /// Returns a value in `[0.0, 1.0)`.
    ///
    /// Keeps the top 53 bits so the result maps exactly onto the f64
    /// mantissa range.
    pub fn next_f64(&mut self) -> f64 {
        let bits = self.next_u64() >> 11;
        bits as f64 / (1u64 << 53) as f64
    }
}
/// A draft model proposes candidate continuation tokens together with the
/// probability it assigned to each sampled token.
pub trait DraftModel: Send + Sync {
    /// Proposes `length` draft tokens continuing `context`.
    ///
    /// Returns `(token, probability)` pairs, one per drafted position.
    fn generate_draft(&self, context: &[usize], length: usize) -> Vec<(usize, f64)>;
    /// Number of tokens in the model's vocabulary.
    fn vocab_size(&self) -> usize;
}
/// An n-gram draft model backed by transition counts from a training corpus.
#[derive(Debug, Clone)]
pub struct NGramDraftModel {
    // Maps each (n-1)-token prefix to counts of the tokens that followed it.
    table: HashMap<Vec<usize>, HashMap<usize, usize>>,
    // Model order; construction enforces n >= 2.
    n: usize,
    // Vocabulary size; sampled token ids are always < vocab.
    vocab: usize,
    // Seed for the per-call RNG, making draft generation deterministic.
    seed: u64,
}
impl NGramDraftModel {
    /// Builds an n-gram model from `corpus`.
    ///
    /// Returns `None` when `n < 2` (no prefix to condition on) or when
    /// `vocab_size` is zero.
    pub fn new(corpus: &[usize], n: usize, vocab_size: usize, seed: u64) -> Option<Self> {
        if n < 2 || vocab_size == 0 {
            return None;
        }
        let table = build_ngram_table(corpus, n);
        Some(Self {
            table,
            n,
            vocab: vocab_size,
            seed,
        })
    }

    /// Returns a probability distribution over the vocabulary conditioned on
    /// the last `n - 1` tokens of `context` (or all of it when shorter).
    ///
    /// Unseen prefixes fall back to the uniform distribution; seen prefixes
    /// use smoothed relative frequencies so every token keeps non-zero mass.
    fn distribution_for_context(&self, context: &[usize]) -> Vec<f64> {
        let prefix_len = self.n - 1;
        let prefix = if context.len() >= prefix_len {
            &context[context.len() - prefix_len..]
        } else {
            context
        };
        let counts = match self.table.get(prefix) {
            Some(counts) => counts,
            None => return vec![1.0 / self.vocab as f64; self.vocab],
        };
        let total: usize = counts.values().sum();
        if total == 0 {
            // Defensive: an empty/zeroed count map degenerates to uniform.
            return vec![1.0 / self.vocab as f64; self.vocab];
        }
        let mut probs = vec![0.0; self.vocab];
        for (&token, &count) in counts {
            // Tokens outside the configured vocabulary are ignored.
            if token < self.vocab {
                probs[token] = count as f64 / total as f64;
            }
        }
        // Additive smoothing keeps every probability strictly positive, then
        // renormalize. After smoothing the sum is always > 0, so dividing is
        // unconditionally safe (the original guarded on `sum_after > 0.0 &&
        // sum_before >= 0.0`, both of which could never be false, and
        // computed `sum_before` only to feed that dead check).
        let smoothing = 1e-10;
        for p in &mut probs {
            *p += smoothing;
        }
        let sum: f64 = probs.iter().sum();
        for p in &mut probs {
            *p /= sum;
        }
        probs
    }

    /// Samples an index from `probs` by inverse-CDF sampling; returns the
    /// sampled index and its probability.
    fn sample_from_probs(probs: &[f64], rng: &mut Xorshift64) -> (usize, f64) {
        let u = rng.next_f64();
        let mut cumulative = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            cumulative += p;
            if u < cumulative {
                return (i, p);
            }
        }
        // Floating-point round-off can leave `u` just past the final
        // cumulative sum; fall back to the last index.
        let last = probs.len().saturating_sub(1);
        (last, probs.get(last).copied().unwrap_or(0.0))
    }
}
impl DraftModel for NGramDraftModel {
    /// Samples `length` tokens autoregressively from the n-gram tables,
    /// feeding each sampled token back into the context.
    ///
    /// The RNG is re-seeded from `self.seed` on every call, so a given
    /// context always yields the same draft.
    fn generate_draft(&self, context: &[usize], length: usize) -> Vec<(usize, f64)> {
        let mut rng = Xorshift64::new(self.seed);
        let mut extended: Vec<usize> = context.to_vec();
        let mut draft = Vec::with_capacity(length);
        while draft.len() < length {
            let probs = self.distribution_for_context(&extended);
            let (token, prob) = Self::sample_from_probs(&probs, &mut rng);
            draft.push((token, prob));
            extended.push(token);
        }
        draft
    }

    fn vocab_size(&self) -> usize {
        self.vocab
    }
}
/// A draft model that proposes tokens uniformly at random from the
/// vocabulary, ignoring context entirely.
#[derive(Debug, Clone)]
pub struct UniformDraftModel {
    // Vocabulary size; every token id in [0, vocab) is a candidate.
    vocab: usize,
    // Seed for the per-call RNG, making drafts deterministic.
    seed: u64,
}

impl UniformDraftModel {
    /// Creates a uniform draft model; returns `None` for an empty vocabulary.
    pub fn new(vocab_size: usize, seed: u64) -> Option<Self> {
        match vocab_size {
            0 => None,
            vocab => Some(Self { vocab, seed }),
        }
    }
}
impl DraftModel for UniformDraftModel {
    /// Draws `length` independent tokens, each uniform over the vocabulary,
    /// every one reported with probability `1 / vocab`.
    ///
    /// Re-seeds the RNG from `self.seed` on each call, so output is
    /// deterministic. NOTE(review): `next_u64() % vocab` carries a modulo
    /// bias of roughly `vocab / 2^64` — negligible for realistic vocabulary
    /// sizes, but worth knowing it is not exactly uniform.
    fn generate_draft(&self, _context: &[usize], length: usize) -> Vec<(usize, f64)> {
        let mut rng = Xorshift64::new(self.seed);
        let p = 1.0 / self.vocab as f64;
        let mut draft = Vec::with_capacity(length);
        for _ in 0..length {
            let token = (rng.next_u64() as usize) % self.vocab;
            draft.push((token, p));
        }
        draft
    }

    fn vocab_size(&self) -> usize {
        self.vocab
    }
}
/// Counts n-gram transitions in `corpus`.
///
/// Maps each (n-1)-token prefix to a map from the following token to the
/// number of times that continuation was observed. Returns an empty table
/// when `n < 2` or the corpus is shorter than `n`.
pub fn build_ngram_table(corpus: &[usize], n: usize) -> HashMap<Vec<usize>, HashMap<usize, usize>> {
    let mut table: HashMap<Vec<usize>, HashMap<usize, usize>> = HashMap::new();
    if n >= 2 && corpus.len() >= n {
        for window in corpus.windows(n) {
            // Split each window into its (n-1)-token prefix and the
            // single token that follows it.
            let (prefix, rest) = window.split_at(n - 1);
            *table
                .entry(prefix.to_vec())
                .or_default()
                .entry(rest[0])
                .or_insert(0) += 1;
        }
    }
    table
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xorshift64_produces_values_in_range() {
        let mut rng = Xorshift64::new(42);
        for _ in 0..1000 {
            let v = rng.next_f64();
            assert!((0.0..1.0).contains(&v), "value out of range: {v}");
        }
    }

    #[test]
    fn test_xorshift64_zero_seed_uses_default() {
        let mut rng = Xorshift64::new(0);
        let v = rng.next_u64();
        assert_ne!(v, 0);
    }

    #[test]
    fn test_build_ngram_table_basic() {
        let corpus = vec![0, 1, 2, 0, 1, 3];
        let table = build_ngram_table(&corpus, 2);

        // "0" is followed by "1" twice.
        let counts_0 = table.get(&vec![0]);
        assert!(counts_0.is_some());
        let counts_0 = counts_0.expect("test: prefix [0] should exist");
        assert_eq!(counts_0.get(&1).copied().unwrap_or(0), 2);

        // "1" is followed once each by "2" and "3".
        let counts_1 = table.get(&vec![1]);
        assert!(counts_1.is_some());
        let counts_1 = counts_1.expect("test: prefix [1] should exist");
        assert_eq!(counts_1.get(&2).copied().unwrap_or(0), 1);
        assert_eq!(counts_1.get(&3).copied().unwrap_or(0), 1);
    }

    #[test]
    fn test_build_ngram_table_too_short() {
        let corpus = vec![0];
        assert!(build_ngram_table(&corpus, 2).is_empty());
    }

    #[test]
    fn test_ngram_draft_generates_correct_length() {
        let corpus = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1];
        let model =
            NGramDraftModel::new(&corpus, 2, 4, 123).expect("test: should build ngram model");
        assert_eq!(model.generate_draft(&[0], 5).len(), 5);
    }

    #[test]
    fn test_ngram_draft_probabilities_positive() {
        let corpus = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let model =
            NGramDraftModel::new(&corpus, 2, 4, 42).expect("test: should build ngram model");
        for (_, prob) in &model.generate_draft(&[0], 3) {
            assert!(*prob > 0.0, "draft probability should be positive");
        }
    }

    #[test]
    fn test_ngram_draft_from_known_distribution() {
        // In this corpus, 0 is always followed by 1.
        let corpus = vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1];
        let model =
            NGramDraftModel::new(&corpus, 2, 4, 99).expect("test: should build ngram model");
        let draft = model.generate_draft(&[0], 1);
        let (token, prob) = draft[0];
        assert_eq!(token, 1, "after 0 should always predict 1");
        assert!(prob > 0.9, "probability should be near 1.0, got {prob}");
    }

    #[test]
    fn test_uniform_draft_correct_length() {
        let model = UniformDraftModel::new(100, 42).expect("test: should build uniform model");
        assert_eq!(model.generate_draft(&[0, 1, 2], 7).len(), 7);
    }

    #[test]
    fn test_uniform_draft_probability() {
        let vocab = 50;
        let model = UniformDraftModel::new(vocab, 42).expect("test: should build uniform model");
        let expected_p = 1.0 / vocab as f64;
        for (_, prob) in &model.generate_draft(&[], 10) {
            assert!(
                (*prob - expected_p).abs() < 1e-12,
                "uniform prob should be {expected_p}, got {prob}"
            );
        }
    }

    #[test]
    fn test_uniform_draft_tokens_in_range() {
        let vocab = 20;
        let model = UniformDraftModel::new(vocab, 12345).expect("test: should build uniform model");
        for (token, _) in &model.generate_draft(&[], 100) {
            assert!(*token < vocab, "token {token} out of vocab range {vocab}");
        }
    }

    #[test]
    fn test_ngram_invalid_params() {
        // n < 2 and vocab == 0 are both rejected.
        assert!(NGramDraftModel::new(&[0, 1], 1, 10, 0).is_none());
        assert!(NGramDraftModel::new(&[0, 1], 2, 0, 0).is_none());
    }

    #[test]
    fn test_uniform_invalid_params() {
        assert!(UniformDraftModel::new(0, 42).is_none());
    }
}