use std::fs::File;
use std::io::{BufRead, BufReader};
use std::time;
use anyhow::Result;
use clap::Parser;
use ct2rs::{Config, Device, GenerationOptions, Generator};
// Command-line arguments for this example. The `///` comments on the fields are picked up by
// clap's derive and shown as the per-argument `--help` text.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// File containing the prompt to send to the model
    #[arg(short, long, value_name = "FILE", default_value = "prompt.txt")]
    prompt: String,
    /// Run on CUDA device 0 instead of the default device
    #[arg(short, long)]
    cuda: bool,
    /// Path to the model to load
    path: String,
}
/// Reads a prompt from `--prompt`, loads the model at `path`, generates a completion and
/// prints it, followed by the wall-clock time spent in `generate_batch`.
///
/// # Errors
/// Returns an error if the prompt file cannot be read, the model cannot be loaded, or
/// generation fails.
fn main() -> Result<()> {
    let args = Args::parse();

    // Read the prompt before loading the model so a bad prompt path fails immediately
    // instead of after a potentially slow model load.
    let prompt = read_prompt(&args.prompt)?;

    // `--cuda` pins execution to CUDA device 0; otherwise the library's default config is used.
    let cfg = if args.cuda {
        Config {
            device: Device::CUDA,
            device_indices: vec![0],
            ..Config::default()
        }
    } else {
        Config::default()
    };
    let g = Generator::new(&args.path, &cfg)?;

    // Only the generation call itself is timed; file I/O and model loading are excluded.
    let now = time::Instant::now();
    let res = g.generate_batch(
        // A batch containing the single prompt read above.
        &[prompt],
        // Fixed settings for this example; every other option keeps its default.
        &GenerationOptions {
            max_length: 30,
            sampling_topk: 10,
            ..GenerationOptions::default()
        },
        None,
    )?;
    let elapsed = now.elapsed();

    // Print the generated strings of each batch entry; the second tuple element is unused here.
    for (r, _) in res {
        println!("{}", r.join("\n"));
    }
    println!("Time taken: {elapsed:?}");
    Ok(())
}

/// Reads the prompt file at `path` and returns its lines joined with `\n`.
///
/// `BufRead::lines` strips each line terminator (`\n` or `\r\n`), so the lines are re-joined
/// with `\n` to keep the prompt's line structure instead of fusing adjacent lines together.
/// The result has no trailing newline.
///
/// # Errors
/// Returns an I/O error if the file cannot be opened (the message includes `path`), or if a
/// line cannot be read or is not valid UTF-8.
fn read_prompt(path: &str) -> std::io::Result<String> {
    let file = File::open(path).map_err(|e| {
        std::io::Error::new(e.kind(), format!("failed to open prompt file `{path}`: {e}"))
    })?;
    let lines = BufReader::new(file)
        .lines()
        .collect::<std::io::Result<Vec<String>>>()?;
    Ok(lines.join("\n"))
}