use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use oxillama_runtime::{sample, Sampler, SamplerConfig};
/// Number of logits generated for each benchmark in this file (the "32k" in the benchmark names).
const VOCAB_SIZE: usize = 32_000;
/// Builds a deterministic pseudo-random logit vector of length `n`.
///
/// Every index is scrambled with a single wrapping multiply-add step, reduced modulo 1000 and
/// rescaled linearly, so each value lies in `[-5.0, 4.99]` and the output is identical on
/// every run.
fn make_logits(n: usize) -> Vec<f32> {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let mut logits = Vec::with_capacity(n);
    for index in 0..n {
        let mixed = (index as u64)
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(INCREMENT);
        // Buckets 0..=999 map linearly onto -5.00..=4.99.
        let bucket = (mixed % 1000) as f32;
        logits.push(bucket / 100.0 - 5.0);
    }
    logits
}
/// Builds a deterministic pseudo-random token history of `len` entries.
///
/// Each position is scrambled with a single wrapping multiply-add step and reduced modulo
/// `vocab_size`, so every returned token id is strictly less than `vocab_size`.
///
/// # Panics
///
/// Panics with a remainder-by-zero error if `vocab_size` is `0` while `len` is non-zero.
fn make_token_history(len: usize, vocab_size: usize) -> Vec<u32> {
    const MULTIPLIER: u64 = 2_862_933_555_777_941_757;
    const INCREMENT: u64 = 3_037_000_499;

    let modulus = vocab_size as u64;
    let mut history = Vec::with_capacity(len);
    for position in 0..len {
        let mixed = (position as u64)
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(INCREMENT);
        history.push((mixed % modulus) as u32);
    }
    history
}
/// Benchmarks `sample` with the `SamplerConfig::greedy()` configuration over a
/// `VOCAB_SIZE`-entry logit vector and an empty token history.
///
/// Registered with Criterion under the name `greedy_32k`.
fn bench_greedy(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let config = SamplerConfig::greedy();
    c.bench_function("greedy_32k", |b| {
        // `iter_batched_ref` lends each fresh clone to the routine and only drops it once
        // the measurement has ended, so freeing the logits buffer is not counted as
        // sampling time (with `iter_batched` the routine owns the input and drops it
        // inside the timed region).
        b.iter_batched_ref(
            || logits.clone(),
            |l| {
                let result = sample(std::hint::black_box(&*l), &config, &[]);
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `sample` with top-k filtering only (`top_k = 40`, temperature `0.8`).
///
/// The other filters are set to the values `top_p = 1.0`, `min_p = 0.0`,
/// `repetition_penalty = 1.0` with a zero-length penalty window, and the seed is fixed to
/// `42`. Registered with Criterion under the name `top_k40_32k`.
fn bench_top_k(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let config = SamplerConfig {
        temperature: 0.8,
        top_k: 40,
        top_p: 1.0,
        min_p: 0.0,
        repetition_penalty: 1.0,
        repetition_penalty_window: 0,
        seed: Some(42),
        ..SamplerConfig::default()
    };
    c.bench_function("top_k40_32k", |b| {
        // Borrowing the batched input keeps the drop of the cloned logits buffer outside
        // the timed region; `iter_batched` would drop it inside the routine.
        b.iter_batched_ref(
            || logits.clone(),
            |l| {
                let result = sample(std::hint::black_box(&*l), &config, &[]);
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `sample` with top-p filtering only (`top_p = 0.9`, temperature `0.8`).
///
/// The other filters are set to the values `top_k = 0`, `min_p = 0.0`,
/// `repetition_penalty = 1.0` with a zero-length penalty window, and the seed is fixed to
/// `42`. Registered with Criterion under the name `top_p0.9_32k`.
fn bench_top_p(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let config = SamplerConfig {
        temperature: 0.8,
        top_k: 0,
        top_p: 0.9,
        min_p: 0.0,
        repetition_penalty: 1.0,
        repetition_penalty_window: 0,
        seed: Some(42),
        ..SamplerConfig::default()
    };
    c.bench_function("top_p0.9_32k", |b| {
        // Borrowing the batched input keeps the drop of the cloned logits buffer outside
        // the timed region; `iter_batched` would drop it inside the routine.
        b.iter_batched_ref(
            || logits.clone(),
            |l| {
                let result = sample(std::hint::black_box(&*l), &config, &[]);
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `sample` with min-p filtering only (`min_p = 0.05`, temperature `0.8`).
///
/// The other filters are set to the values `top_k = 0`, `top_p = 1.0`,
/// `repetition_penalty = 1.0` with a zero-length penalty window, and the seed is fixed to
/// `42`. Registered with Criterion under the name `min_p0.05_32k`.
fn bench_min_p(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let config = SamplerConfig {
        temperature: 0.8,
        top_k: 0,
        top_p: 1.0,
        min_p: 0.05,
        repetition_penalty: 1.0,
        repetition_penalty_window: 0,
        seed: Some(42),
        ..SamplerConfig::default()
    };
    c.bench_function("min_p0.05_32k", |b| {
        // Borrowing the batched input keeps the drop of the cloned logits buffer outside
        // the timed region; `iter_batched` would drop it inside the routine.
        b.iter_batched_ref(
            || logits.clone(),
            |l| {
                let result = sample(std::hint::black_box(&*l), &config, &[]);
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `sample` with a repetition penalty of `1.1` applied over a 200-token history.
///
/// The configuration also enables `top_k = 40` and `top_p = 0.9` at temperature `0.8`, with
/// `min_p = 0.0` and the seed fixed to `42`. Registered with Criterion under the name
/// `rep_penalty_200hist_32k`.
fn bench_repetition_penalty(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let history = make_token_history(200, VOCAB_SIZE);
    let config = SamplerConfig {
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
        min_p: 0.0,
        repetition_penalty: 1.1,
        repetition_penalty_window: 200,
        seed: Some(42),
        ..SamplerConfig::default()
    };
    c.bench_function("rep_penalty_200hist_32k", |b| {
        // Borrowing the batched input keeps the drops of the cloned logits and history
        // buffers outside the timed region; `iter_batched` would drop both inside the
        // routine.
        b.iter_batched_ref(
            || (logits.clone(), history.clone()),
            |(l, h)| {
                let result = sample(
                    std::hint::black_box(&*l),
                    &config,
                    std::hint::black_box(&*h),
                );
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `Sampler::sample` using the `SamplerConfig::mirostat_v2(5.0, 0.1)` configuration
/// with the seed fixed to `42` and an empty token history.
///
/// A new `Sampler` is built in the setup closure for every iteration, so each timed call
/// starts from a freshly constructed sampler. Registered with Criterion under the name
/// `mirostat_v2_32k`.
fn bench_mirostat_v2(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let config = SamplerConfig {
        seed: Some(42),
        ..SamplerConfig::mirostat_v2(5.0, 0.1)
    };
    c.bench_function("mirostat_v2_32k", |b| {
        // Borrowing the batched input keeps the drops of the cloned logits buffer and the
        // sampler outside the timed region; `iter_batched` would drop both inside the
        // routine.
        b.iter_batched_ref(
            || (logits.clone(), Sampler::new(config.clone())),
            |(l, sampler)| {
                let result = sampler.sample(std::hint::black_box(&*l), &[]);
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
/// Benchmarks `sample` with every filter enabled at once over a 64-token history.
///
/// The configuration uses temperature `0.7`, `top_k = 40`, `top_p = 0.9`, `min_p = 0.05`, and
/// a repetition penalty of `1.1` over a 64-token window, with the seed fixed to `42`.
/// Registered with Criterion under the name `combined_default_32k`.
fn bench_combined(c: &mut Criterion) {
    let logits = make_logits(VOCAB_SIZE);
    let history = make_token_history(64, VOCAB_SIZE);
    let config = SamplerConfig {
        temperature: 0.7,
        top_k: 40,
        top_p: 0.9,
        min_p: 0.05,
        repetition_penalty: 1.1,
        repetition_penalty_window: 64,
        seed: Some(42),
        ..SamplerConfig::default()
    };
    c.bench_function("combined_default_32k", |b| {
        // Borrowing the batched input keeps the drops of the cloned logits and history
        // buffers outside the timed region; `iter_batched` would drop both inside the
        // routine.
        b.iter_batched_ref(
            || (logits.clone(), history.clone()),
            |(l, h)| {
                let result = sample(
                    std::hint::black_box(&*l),
                    &config,
                    std::hint::black_box(&*h),
                );
                std::hint::black_box(result)
            },
            BatchSize::SmallInput,
        )
    });
}
// Collect the sampler benchmark functions of this file into a Criterion group named `benches`.
criterion_group!(
benches,
bench_greedy,
bench_top_k,
bench_top_p,
bench_min_p,
bench_repetition_penalty,
bench_mirostat_v2,
bench_combined,
);
// Expand to the benchmark binary's `main` function, which runs the `benches` group.
criterion_main!(benches);