use anyhow::{bail, Result};
use std::time::SystemTime;
use tch::vision::imagenet;
/// Number of timed inference runs used to compute the mean latency.
const NRUNS: i32 = 10;

/// Loads a (possibly quantized) TorchScript model, runs it on an image
/// resized for ImageNet (224x224), reports the mean inference latency over
/// `NRUNS` runs, and prints the top-5 predicted classes.
///
/// Usage: `main model.pt image.jpg [fbgemm | qnnpack]`
///
/// # Errors
/// Returns an error on bad CLI arguments, an unknown quantization engine,
/// or failure to load the image or the model.
pub fn main() -> Result<()> {
    let args: Vec<_> = std::env::args().collect();
    let (model_file, image_file, qengine) = match args.as_slice() {
        [_, m, i, q] => (m.to_owned(), i.to_owned(), Some(q.to_owned())),
        [_, m, i] => (m.to_owned(), i.to_owned(), None),
        _ => bail!("usage: main model.pt image.jpg [fbgemm | qnnpack]"),
    };
    // Select the quantization backend (if requested) before loading the model.
    if let Some(qengine) = qengine {
        match qengine.as_str() {
            "fbgemm" => tch::QEngine::FBGEMM.set()?,
            "qnnpack" => tch::QEngine::QNNPACK.set()?,
            _ => bail!("qengine should be one of 'fbgemm' or 'qnnpack' or omitted"),
        }
    }
    let image = imagenet::load_image_and_resize224(image_file)?;
    let model = tch::CModule::load(model_file)?;
    // Hoist the batch-dimension insertion out of the timed loop so every
    // iteration measures only the model forward pass.
    let batch = image.unsqueeze(0);
    // `Instant` is monotonic (unlike `SystemTime`), so the measurement cannot
    // go backwards mid-benchmark and `elapsed()` is infallible — no unwrap.
    let start = std::time::Instant::now();
    // BUGFIX: the original `1..NRUNS` ran only NRUNS-1 iterations while the
    // mean below divides by NRUNS, underreporting latency by ~10%.
    for _ in 0..NRUNS {
        let _output = batch.apply(&model);
    }
    println!("Mean Inference Time: {} ms", start.elapsed().as_millis() / NRUNS as u128);
    let output = batch.apply(&model).softmax(-1, tch::Kind::Float);
    println!("Top 5 Predictions:");
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }
    Ok(())
}