ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
#[cfg(feature = "ndarray")]
pub mod mnist {
	use image::{ImageBuffer, Luma, Pixel};
	use ndarray::{Array4, Axis};

	use crate::{
		Result,
		util::ArrayExt,
		value::{TensorValueTypeMarker, Value}
	};

	pub const MODEL_URL: &str = "https://cdn.pyke.io/0/pyke:ort-rs/example-models@0.0.0/mnist.onnx";

	pub fn get_image() -> Array4<f32> {
		let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open("tests/data/mnist_5.jpg").expect("failed to load image").to_luma8();
		ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
			let pixel = image_buffer.get_pixel(i as u32, j as u32);
			let channels = pixel.channels();
			(channels[c] as f32) / 255.0
		})
	}

	pub fn extract_probabilities<T: TensorValueTypeMarker>(output: &Value<T>) -> Result<Vec<(usize, f32)>> {
		let mut probabilities: Vec<(usize, f32)> = output
			.try_extract_array()?
			.softmax(Axis(1))
			.iter()
			.copied()
			.enumerate()
			.collect::<Vec<_>>();
		probabilities.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
		Ok(probabilities)
	}
}