use std::{hint::black_box, sync::Arc};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use llm_tokenizer::{
mock::MockTokenizer,
sequence::Sequence,
stop::{StopSequenceConfig, StopSequenceDecoder},
};
/// Benchmarks feeding a whole token stream through a `StopSequenceDecoder`
/// for every combination of stop-sequence count (1, 5, 10) and stream length
/// (100, 500, 2000).
///
/// Each timed iteration builds the config and the decoder from scratch and
/// then calls `process_token` once per token, so construction cost is part of
/// the measurement.
fn bench_stop_decoder_process_tokens(c: &mut Criterion) {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut group = c.benchmark_group("stop_decoder_process_tokens");
    // Repeating pattern the input streams are built from.
    let pattern: Vec<u32> = vec![1, 2, 3, 4, 6];

    for num_stops in [1, 5, 10] {
        // Synthetic stop strings; presumably none of them is ever produced by
        // the cycled tokens — confirm against MockTokenizer's vocabulary.
        let stops: Vec<String> = (0..num_stops)
            .map(|idx| format!("STOP_SEQUENCE_{idx}_END"))
            .collect();

        for num_tokens in [100, 500, 2000] {
            let stream: Vec<u32> = pattern.iter().copied().cycle().take(num_tokens).collect();
            let id = BenchmarkId::new(
                format!("{num_stops}_stops"),
                format!("{num_tokens}_tokens"),
            );

            group.bench_with_input(id, &stream, |bencher, stream| {
                bencher.iter(|| {
                    // Accumulate all stop sequences onto a default config.
                    let config = stops
                        .iter()
                        .fold(StopSequenceConfig::default(), |cfg, stop| {
                            cfg.with_stop_sequence(stop)
                        });
                    let mut decoder = StopSequenceDecoder::new(tokenizer.clone(), config, false);
                    for token_id in stream.iter().copied() {
                        let _ = black_box(decoder.process_token(token_id));
                    }
                });
            });
        }
    }

    group.finish();
}
/// Benchmarks a decoder configured with the stop sequence `"Hello world"`
/// that first consumes 10, 100 or 500 filler tokens and is then fed token ids
/// 1 and 2.
///
/// Each timed iteration builds the config and decoder from scratch.
fn bench_stop_decoder_with_match(c: &mut Criterion) {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut group = c.benchmark_group("stop_decoder_with_match");
    // Filler pattern repeated ahead of the final two tokens.
    let filler: [u32; 3] = [3, 4, 6];

    for num_tokens_before_stop in [10, 100, 500] {
        let tokens_before: Vec<u32> = filler
            .iter()
            .copied()
            .cycle()
            .take(num_tokens_before_stop)
            .collect();
        let id = BenchmarkId::from_parameter(format!("{num_tokens_before_stop}_before_match"));

        group.bench_with_input(id, &tokens_before, |bencher, prefix| {
            bencher.iter(|| {
                let config = StopSequenceConfig::default().with_stop_sequence("Hello world");
                let mut decoder = StopSequenceDecoder::new(tokenizer.clone(), config, false);

                for token_id in prefix.iter().copied() {
                    let _ = black_box(decoder.process_token(token_id));
                }

                // NOTE(review): tokens 1 and 2 presumably decode to "Hello"
                // and "world" in MockTokenizer, triggering the stop match —
                // confirm against the mock's vocabulary.
                let _ = black_box(decoder.process_token(1));
                let _ = black_box(decoder.process_token(2));
            });
        });
    }

    group.finish();
}
/// Benchmarks a decoder configured with a single stop token (id 999) and no
/// stop sequences, over streams of 100, 500 and 2000 tokens.
///
/// The input only ever contains the ids 1, 2, 3, 4 and 6, so the configured
/// stop token is never fed to the decoder. Each timed iteration builds the
/// config and decoder from scratch.
fn bench_stop_decoder_token_only(c: &mut Criterion) {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut group = c.benchmark_group("stop_decoder_token_only");
    let pattern: Vec<u32> = vec![1, 2, 3, 4, 6];

    for num_tokens in [100, 500, 2000] {
        // Walk the pattern round-robin until `num_tokens` ids are produced.
        let stream: Vec<u32> = (0..num_tokens)
            .map(|idx| pattern[idx % pattern.len()])
            .collect();
        let id = BenchmarkId::from_parameter(format!("{num_tokens}_tokens"));

        group.bench_with_input(id, &stream, |bencher, stream| {
            bencher.iter(|| {
                let config = StopSequenceConfig::default().with_stop_token(999);
                let mut decoder = StopSequenceDecoder::new(tokenizer.clone(), config, false);
                stream.iter().for_each(|&token_id| {
                    let _ = black_box(decoder.process_token(token_id));
                });
            });
        });
    }

    group.finish();
}
/// Benchmarks appending 100, 500 and 2000 tokens, one at a time, to a freshly
/// created `Sequence`.
///
/// Each timed iteration constructs a new `Sequence` and calls `append_token`
/// once per input token.
fn bench_sequence_append_token(c: &mut Criterion) {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut group = c.benchmark_group("sequence_append_token");
    let pattern: Vec<u32> = vec![1, 2, 3, 4, 6];

    for num_tokens in [100, 500, 2000] {
        let stream: Vec<u32> = pattern.iter().copied().cycle().take(num_tokens).collect();
        let id = BenchmarkId::from_parameter(format!("{num_tokens}_tokens"));

        group.bench_with_input(id, &stream, |bencher, stream| {
            bencher.iter(|| {
                let mut sequence = Sequence::new(tokenizer.clone());
                for token_id in stream.iter().copied() {
                    let _ = black_box(sequence.append_token(token_id));
                }
            });
        });
    }

    group.finish();
}
/// Benchmarks the cost of a single `process_token` call on a decoder that has
/// already consumed `position` tokens (100 through 5000).
///
/// Two decoder configurations are measured per position: one with five stop
/// sequences (`5_stops`) and one with a single stop token (`token_only`).
///
/// The prefilled decoder is built in the untimed setup closure and handed to
/// the routine by mutable reference via `iter_batched_ref`. Taking it by
/// value (as `iter_batched` does) would drop the decoder at the end of the
/// routine, i.e. inside the timed section, and that teardown cost would be
/// folded into the per-token figure.
fn bench_per_token_at_position(c: &mut Criterion) {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut group = c.benchmark_group("per_token_at_position");
    let token_cycle: Vec<u32> = vec![1, 2, 3, 4, 6];
    // The stop sequences do not depend on the position, so build them once.
    let stop_sequences: Vec<String> = (0..5)
        .map(|i| format!("STOP_SEQUENCE_{i}_END"))
        .collect();

    for position in [100, 500, 1000, 2000, 5000] {
        let prefill_tokens: Vec<u32> = token_cycle
            .iter()
            .cycle()
            .take(position)
            .copied()
            .collect();

        group.bench_with_input(
            BenchmarkId::new("5_stops", format!("pos_{position}")),
            &prefill_tokens,
            |b, prefill_tokens| {
                b.iter_batched_ref(
                    // Untimed: build a decoder and advance it to `position`.
                    || {
                        let mut config = StopSequenceConfig::default();
                        for s in &stop_sequences {
                            config = config.with_stop_sequence(s);
                        }
                        let mut decoder =
                            StopSequenceDecoder::new(tokenizer.clone(), config, false);
                        for &token_id in prefill_tokens {
                            let _ = decoder.process_token(token_id);
                        }
                        decoder
                    },
                    // Timed: exactly one more token.
                    |decoder| {
                        let _ = black_box(decoder.process_token(3));
                    },
                    criterion::BatchSize::SmallInput,
                );
            },
        );

        group.bench_with_input(
            BenchmarkId::new("token_only", format!("pos_{position}")),
            &prefill_tokens,
            |b, prefill_tokens| {
                b.iter_batched_ref(
                    // Untimed: build a decoder and advance it to `position`.
                    || {
                        let config = StopSequenceConfig::default().with_stop_token(999);
                        let mut decoder =
                            StopSequenceDecoder::new(tokenizer.clone(), config, false);
                        for &token_id in prefill_tokens {
                            let _ = decoder.process_token(token_id);
                        }
                        decoder
                    },
                    // Timed: exactly one more token.
                    |decoder| {
                        let _ = black_box(decoder.process_token(3));
                    },
                    criterion::BatchSize::SmallInput,
                );
            },
        );
    }

    group.finish();
}
// Collects every benchmark function defined above into the `benches` group
// and generates the `main` entry point that runs that group.
criterion_group!(
benches,
bench_stop_decoder_process_tokens,
bench_stop_decoder_with_match,
bench_stop_decoder_token_only,
bench_sequence_append_token,
bench_per_token_at_position,
);
criterion_main!(benches);