use std::{
fs,
io::{self, BufRead, BufReader},
path::Path,
time::Duration,
};
use onnxruntime::error::OrtDownloadError;
mod download {
use super::*;
use image::{imageops::FilterType, ImageBuffer, Luma, Pixel, Rgb};
use ndarray::s;
use test_env_log::test;
use onnxruntime::{
download::vision::{DomainBasedImageClassification, ImageClassification},
environment::Environment,
GraphOptimizationLevel, LoggingLevel,
};
#[test]
fn squeezenet_mushroom() {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let mut session = environment
.new_session_builder()
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Basic)
.unwrap()
.with_number_threads(1)
.unwrap()
.with_model_downloaded(ImageClassification::SqueezeNet)
.expect("Could not download model from file");
let class_labels = get_imagenet_labels().unwrap();
let input0_shape: Vec<usize> = session.inputs[0].dimensions().map(|d| d.unwrap()).collect();
let output0_shape: Vec<usize> = session.outputs[0]
.dimensions()
.map(|d| d.unwrap())
.collect();
assert_eq!(input0_shape, [1, 3, 224, 224]);
assert_eq!(output0_shape, [1, 1000]);
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("data")
.join(IMAGE_TO_LOAD),
)
.unwrap()
.resize(
input0_shape[2] as u32,
input0_shape[3] as u32,
FilterType::Nearest,
)
.to_rgb8();
let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
(channels[c] as f32) / 255.0
});
let mean = [0.485, 0.456, 0.406];
let std = [0.229, 0.224, 0.225];
for c in 0..3 {
let mut channel_array = array.slice_mut(s![0, c, .., ..]);
channel_array -= mean[c];
channel_array /= std[c];
}
let input_tensor_values = vec![array];
let outputs: Vec<
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
> = session.run(input_tensor_values).unwrap();
let mut probabilities: Vec<(usize, f32)> = outputs[0]
.softmax(ndarray::Axis(1))
.into_iter()
.copied()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
assert_eq!(
class_labels[probabilities[0].0], "n07734744 mushroom",
"Expecting class for {} to be a mushroom",
IMAGE_TO_LOAD
);
assert_eq!(
probabilities[0].0, 947,
"Expecting class for {} to be a mushroom (index 947 in labels file)",
IMAGE_TO_LOAD
);
}
#[test]
fn mnist_5() {
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let mut session = environment
.new_session_builder()
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Basic)
.unwrap()
.with_number_threads(1)
.unwrap()
.with_model_downloaded(DomainBasedImageClassification::Mnist)
.expect("Could not download model from file");
let input0_shape: Vec<usize> = session.inputs[0].dimensions().map(|d| d.unwrap()).collect();
let output0_shape: Vec<usize> = session.outputs[0]
.dimensions()
.map(|d| d.unwrap())
.collect();
assert_eq!(input0_shape, [1, 1, 28, 28]);
assert_eq!(output0_shape, [1, 10]);
let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("data")
.join(IMAGE_TO_LOAD),
)
.unwrap()
.resize(
input0_shape[2] as u32,
input0_shape[3] as u32,
FilterType::Nearest,
)
.to_luma8();
let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
(channels[c] as f32) / 255.0
});
let input_tensor_values = vec![array];
let outputs: Vec<
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
> = session.run(input_tensor_values).unwrap();
let mut probabilities: Vec<(usize, f32)> = outputs[0]
.softmax(ndarray::Axis(1))
.into_iter()
.copied()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
assert_eq!(
probabilities[0].0, 5,
"Expecting class for {} is '5' (not {})",
IMAGE_TO_LOAD, probabilities[0].0
);
}
#[test]
fn upsample() {
const IMAGE_TO_LOAD: &str = "mushroom.png";
let environment = Environment::builder()
.with_name("integration_test")
.with_log_level(LoggingLevel::Warning)
.build()
.unwrap();
let mut session = environment
.new_session_builder()
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Basic)
.unwrap()
.with_number_threads(1)
.unwrap()
.with_model_from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("data")
.join("upsample.onnx"),
)
.expect("Could not open model from file");
assert_eq!(
session.inputs[0].dimensions().collect::<Vec<_>>(),
[None, None, None, Some(3)]
);
assert_eq!(
session.outputs[0].dimensions().collect::<Vec<_>>(),
[None, None, None, Some(3)]
);
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("data")
.join(IMAGE_TO_LOAD),
)
.unwrap()
.to_rgb8();
let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| {
let pixel = image_buffer.get_pixel(i as u32, j as u32);
let channels = pixel.channels();
(channels[c] as f32) / 255.0
});
let input_tensor_values = vec![array];
let outputs: Vec<
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
> = session.run(input_tensor_values).unwrap();
assert_eq!(outputs.len(), 1);
let output = &outputs[0];
assert_eq!(output.shape(), [1, 448, 448, 3]);
}
}
fn get_imagenet_labels() -> Result<Vec<String>, OrtDownloadError> {
let labels_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("synset.txt");
if !labels_path.exists() {
let url = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt";
println!("Downloading {:?} to {:?}...", url, labels_path);
let resp = ureq::get(url)
.timeout(Duration::from_secs(180)) .call()
.map_err(Box::new)
.map_err(OrtDownloadError::UreqError)?;
assert!(resp.has("Content-Length"));
let len = resp
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.unwrap();
println!("Downloading {} bytes...", len);
let mut reader = resp.into_reader();
let f = fs::File::create(&labels_path).unwrap();
let mut writer = io::BufWriter::new(f);
let bytes_io_count = io::copy(&mut reader, &mut writer).unwrap();
assert_eq!(bytes_io_count, len as u64);
}
let file = BufReader::new(fs::File::open(labels_path).unwrap());
file.lines()
.map(|line| line.map_err(|io_err| OrtDownloadError::IoError(io_err)))
.collect()
}