use ort2::prelude::*;

// ONNX model bytes embedded into the binary at compile time.
let model_bytes = include_bytes!("models/mnist-8.onnx");

// Build an inference session directly from the in-memory model.
let session = Session::builder()
    .build(model_bytes.as_ref())
    .expect("failed to create session");

// A zero-filled 28x28 single-channel image as the network input.
let pixels = vec![0.0f32; 28 * 28];

// Describe the buffer as a [1, 1, 28, 28] f32 tensor, borrowing it
// rather than copying the data.
let input_value = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&pixels)
    .expect("failed to build value");

// Run inference and keep only the first output value.
let mut outputs = session
    .run([&input_value])
    .expect("failed to run")
    .into_iter();
let first_output = outputs.next().expect("failed to get outputs");

// View the raw output buffer as f32; axis 1 carries the 10 class scores.
let scores = first_output
    .view::<f32>()
    .expect("failed to view output");
assert_eq!(scores.shape()[1], 10);
use ort2::prelude::*;

// ONNX model bytes embedded into the binary at compile time.
let model_bytes = include_bytes!("models/mnist-8.onnx");

let session = Session::builder()
    .build(model_bytes.as_ref())
    .expect("failed to create session");

// A zero-filled 28x28 single-channel image as the network input.
let pixels = vec![0.0f32; 28 * 28];

let input_value = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&pixels)
    .expect("failed to build value");

// Use explicit IO binding instead of handing values to `run` directly:
// inputs and outputs are attached to the binding by name beforehand.
let mut iobinding = session
    .iobinding()
    .expect("failed to create iobinding");

// Bind our tensor to the model's first input by its declared name.
let model_inputs = session.get_inputs().expect("failed to get input");
iobinding
    .bind_input(&model_inputs[0].name, &input_value)
    .expect("failed to bind input");

// Bind the first output to the default device so the runtime
// allocates it for us.
let mem_info = MemoryInfo::default();
let model_outputs = session.get_outputs().expect("failed to get outputs");
iobinding
    .bind_output_to_device(&model_outputs[0].name, &mem_info)
    .expect("failed to bind outputs");

session
    .run_with_iobinding(&mut iobinding)
    .expect("failed to run");

// Retrieve the bound outputs back through the default allocator and
// take the first one.
let alloc = DefaultAllocator::default();
let mut bound = iobinding
    .get_bound_outputs(&alloc)
    .expect("failed to get output from iobinding")
    .into_iter();
let first_output = bound.next().expect("failed to get output");

// View the raw output buffer as f32; axis 1 carries the 10 class scores.
let scores = first_output
    .view::<f32>()
    .expect("failed to view output");
assert_eq!(scores.shape()[1], 10);