use std::fs::File;
use std::io::{BufRead, BufReader};
use std::time;

use anyhow::{Context, Result};
use clap::Parser;
use ct2rs::{Config, Device, Translator};
// Command-line arguments for the translation example.
//
// Field `///` docs below are picked up by clap's derive and shown in `--help`.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// File containing the source text to translate.
    #[arg(short, long, value_name = "FILE", default_value = "prompt.txt")]
    prompt: String,

    /// Run on CUDA device 0 instead of the default device.
    #[arg(short, long)]
    cuda: bool,

    /// Path to the CTranslate2 model.
    path: String,
}
fn main() -> Result<()> {
let args = Args::parse();
let cfg = if args.cuda {
Config {
device: Device::CUDA,
device_indices: vec![0],
..Config::default()
}
} else {
Config::default()
};
let t = Translator::new(&args.path, &cfg)?;
let source = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;
let now = time::Instant::now();
let res = t.translate_batch(&[source], &Default::default(), None)?;
let elapsed = now.elapsed();
for (r, _) in res {
println!("{}", r.replace("<s>", ""));
}
println!("Time taken: {elapsed:?}");
Ok(())
}