docs.rs failed to build tflite-0.9.8
Please check the
build logs for more information.
See
Builds for ideas on how to fix a failed build,
or
Metadata for how to configure docs.rs builds.
If you believe this is docs.rs' fault,
open an issue.
Rust bindings for TensorFlow Lite
This crates provides TensorFlow Lite APIs.
Please read the API documentation on docs.rs
Using the interpreter from a model file
The following example shows how to use the TensorFlow Lite interpreter when provided a TensorFlow Lite FlatBuffer file.
The example also demonstrates how to run inference on input data.
use std::fs::{self, File};
use std::io::Read;
use tflite::ops::builtin::BuiltinOpResolver;
use tflite::{FlatBufferModel, InterpreterBuilder, Result};
fn test_mnist(model: &FlatBufferModel) -> Result<()> {
let resolver = BuiltinOpResolver::default();
let builder = InterpreterBuilder::new(model, &resolver)?;
let mut interpreter = builder.build()?;
interpreter.allocate_tensors()?;
let inputs = interpreter.inputs().to_vec();
assert_eq!(inputs.len(), 1);
let input_index = inputs[0];
let outputs = interpreter.outputs().to_vec();
assert_eq!(outputs.len(), 1);
let output_index = outputs[0];
let input_tensor = interpreter.tensor_info(input_index).unwrap();
assert_eq!(input_tensor.dims, vec![1, 28, 28, 1]);
let output_tensor = interpreter.tensor_info(output_index).unwrap();
assert_eq!(output_tensor.dims, vec![1, 10]);
let mut input_file = File::open("data/mnist10.bin")?;
for i in 0..10 {
input_file.read_exact(interpreter.tensor_data_mut(input_index)?)?;
interpreter.invoke()?;
let output: &[u8] = interpreter.tensor_data(output_index)?;
let guess = output.iter().enumerate().max_by(|x, y| x.1.cmp(y.1)).unwrap().0;
println!("{}: {:?}", i, output);
assert_eq!(i, guess);
}
Ok(())
}
#[test]
fn mobilenetv1_mnist() -> Result<()> {
test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_uint8_quant.tflite")?)?;
let buf = fs::read("data/MNISTnet_uint8_quant.tflite")?;
test_mnist(&FlatBufferModel::build_from_buffer(buf)?)
}
#[test]
fn mobilenetv2_mnist() -> Result<()> {
test_mnist(&FlatBufferModel::build_from_file("data/MNISTnet_v2_uint8_quant.tflite")?)?;
let buf = fs::read("data/MNISTnet_v2_uint8_quant.tflite")?;
test_mnist(&FlatBufferModel::build_from_buffer(buf)?)
}
Using the FlatBuffers model APIs
This crate also provides a limited set of FlatBuffers model APIs.
use tflite::model::stl::vector::{VectorInsert, VectorErase, VectorSlice};
use tflite::model::{BuiltinOperator, BuiltinOptions, Model, SoftmaxOptionsT};
#[test]
fn flatbuffer_model_apis_inspect() {
let model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
assert_eq!(model.version, 3);
assert_eq!(model.operator_codes.size(), 5);
assert_eq!(model.subgraphs.size(), 1);
assert_eq!(model.buffers.size(), 24);
assert_eq!(
model.description.c_str().to_string_lossy(),
"TOCO Converted."
);
assert_eq!(
model.operator_codes[0].builtin_code,
BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D
);
assert_eq!(
model
.operator_codes
.iter()
.map(|oc| oc.builtin_code)
.collect::<Vec<_>>(),
vec![
BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator::BuiltinOperator_CONV_2D,
BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOperator::BuiltinOperator_SOFTMAX,
BuiltinOperator::BuiltinOperator_RESHAPE
]
);
let subgraph = &model.subgraphs[0];
assert_eq!(subgraph.tensors.size(), 23);
assert_eq!(subgraph.operators.size(), 9);
assert_eq!(subgraph.inputs.as_slice(), &[22]);
assert_eq!(subgraph.outputs.as_slice(), &[21]);
let softmax = subgraph
.operators
.iter()
.position(|op| {
model.operator_codes[op.opcode_index as usize].builtin_code
== BuiltinOperator::BuiltinOperator_SOFTMAX
})
.unwrap();
assert_eq!(subgraph.operators[softmax].inputs.as_slice(), &[4]);
assert_eq!(subgraph.operators[softmax].outputs.as_slice(), &[21]);
assert_eq!(
subgraph.operators[softmax].builtin_options.type_,
BuiltinOptions::BuiltinOptions_SoftmaxOptions
);
let softmax_options: &SoftmaxOptionsT = subgraph.operators[softmax].builtin_options.as_ref();
assert_eq!(softmax_options.beta, 1.);
}
#[test]
fn flatbuffer_model_apis_mutate() {
let mut model = Model::from_file("data/MNISTnet_uint8_quant.tflite").unwrap();
model.version = 2;
model.operator_codes.erase(4);
model.buffers.erase(22);
model.buffers.erase(23);
model
.description
.assign(CString::new("flatbuffer").unwrap());
{
let subgraph = &mut model.subgraphs[0];
subgraph.inputs.erase(0);
subgraph.outputs.assign(vec![1, 2, 3, 4]);
}
let model_buffer = model.to_buffer();
let model = Model::from_buffer(&model_buffer);
assert_eq!(model.version, 2);
assert_eq!(model.operator_codes.size(), 4);
assert_eq!(model.subgraphs.size(), 1);
assert_eq!(model.buffers.size(), 22);
assert_eq!(model.description.c_str().to_string_lossy(), "flatbuffer");
let subgraph = &model.subgraphs[0];
assert_eq!(subgraph.tensors.size(), 23);
assert_eq!(subgraph.operators.size(), 9);
assert!(subgraph.inputs.as_slice().is_empty());
assert_eq!(subgraph.outputs.as_slice(), &[1, 2, 3, 4]);
}