extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
/// Example: generate text continuations for two prompts with a GPT-2 model.
///
/// Model resources are read from `$HOME/rustbert/gpt2/`:
/// `config.json`, `vocab.txt`, `merges.txt` and `model.ot`.
/// The model runs on CUDA when available, otherwise on the CPU.
/// Each generated sentence is printed with its `Debug` representation.
///
/// # Errors
///
/// Returns an error if:
/// - the home directory cannot be determined;
/// - any of the four resource files is missing;
/// - `GPT2Generator::new` fails.
fn main() -> failure::Fallible<()> {
    // Propagate an error instead of panicking when no home directory is known.
    let model_dir: PathBuf = dirs::home_dir()
        .ok_or_else(|| err_msg("Could not determine the home directory"))?
        .join("rustbert")
        .join("gpt2");

    let config_path = model_dir.join("config.json");
    let vocab_path = model_dir.join("vocab.txt");
    let merges_path = model_dir.join("merges.txt");
    let weights_path = model_dir.join("model.ot");

    // Fail early with an actionable message if any resource is absent.
    let resources = [&config_path, &vocab_path, &merges_path, &weights_path];
    if resources.iter().any(|path| !path.is_file()) {
        return Err(
            err_msg("Could not find required resources to run example. \
                     Please run ../utils/download_dependencies_gpt2.py \
                     in a Python environment with dependencies listed in ../requirements.txt"));
    }

    let device = Device::cuda_if_available();
    // Sampling combined with 5 beams; 3 return sequences requested.
    let generate_config = GenerateConfig {
        max_length: 30,
        do_sample: true,
        num_beams: 5,
        temperature: 1.1,
        num_return_sequences: 3,
        ..Default::default()
    };
    let model = GPT2Generator::new(&vocab_path, &merges_path, &config_path, &weights_path,
                                   generate_config, device)?;

    let input_context = "The dog";
    let second_input_context = "The cat was";
    let output = model.generate(Some(vec![input_context, second_input_context]), None);

    for sentence in output {
        println!("{:?}", sentence);
    }
    Ok(())
}