#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use std::time::{Duration, Instant};
use tch::Device;
/// Builds a GPT-2 `TextGenerationModel` from the pretrained remote resources
/// (weights, config, vocab, BPE merges), configured for sampled beam search:
/// 5 beams, nucleus sampling with top_p = 0.9, up to 30 generated tokens,
/// and 5 returned sequences. Runs on CUDA when available, else CPU.
///
/// NOTE(review): resources are fetched remotely on first use, so the first
/// benchmark run includes download time — presumably cached afterwards.
fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig {
model_type: ModelType::GPT2,
// Pretrained GPT-2 weights, fetched (and cached) from the remote hub.
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2,
))),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
// GPT-2 uses BPE, so merges are required (hence `Some`).
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: Some(30),
// Sampling combined with 5 beams: stochastic beam sampling, so generated
// text (and therefore timing) varies run to run.
do_sample: true,
early_stopping: false,
num_beams: 5,
temperature: 1.0,
// top_k = 0 disables top-k filtering; top_p = 0.9 is the active filter.
top_k: 0,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 3,
num_beam_groups: None,
diversity_penalty: None,
num_return_sequences: 5,
device: Device::cuda_if_available(),
kind: None,
};
// Panics if resources cannot be fetched or the model fails to load;
// acceptable in a benchmark binary.
TextGenerationModel::new(config).unwrap()
}
/// Runs `iters` generation passes over `data` and returns the cumulative
/// wall-clock time spent inside `model.generate`, excluding loop bookkeeping.
///
/// Shaped for Criterion's `iter_custom`, which hands the callee an iteration
/// count and expects it to do its own timing and report a total `Duration`.
fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str]) -> Duration {
    let mut total = Duration::ZERO;
    for _ in 0..iters {
        let start = Instant::now();
        // Output is discarded: only the latency of the call is measured.
        let _ = model.generate(data, None);
        // `+=` on `Duration` panics on overflow exactly like
        // `checked_add(..).unwrap()` did, so this is equivalent but simpler.
        total += start.elapsed();
    }
    total
}
/// Registers the "Generation" benchmark: the model is built once up front so
/// only the repeated `generate` calls on a fixed prompt are timed.
fn bench_generation(c: &mut Criterion) {
    let model = create_text_generation_model();
    let prompts = ["Hello, I'm a language model,"];
    c.bench_function("Generation", |bencher| {
        bencher.iter_custom(|iter_count| {
            // `black_box` keeps the optimizer from discarding the timing result.
            black_box(generation_forward_pass(iter_count, &model, &prompts))
        })
    });
}
// Benchmark harness entry point. sample_size(10) (down from Criterion's
// default of 100) keeps total runtime manageable, since each sample performs
// full GPT-2 beam-sampled generation.
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_generation
}
criterion_main!(benches);