use crate::bindings::tflite;
use core::fmt;
// Make the TensorFlow Lite Micro resolver headers available to the `cpp!`
// closures below, which construct `tflite::ops::micro::AllOpsResolver` and
// `tflite::MicroMutableOpResolver<128>` objects.
cpp! {{
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
}}
/// Binding-level type used as the Rust-side storage for both
/// [`AllOpResolver`] and [`MutableOpResolver`].
///
/// NOTE(review): `MutableOpResolver::empty` returns a C++
/// `tflite::MicroMutableOpResolver<128>` into this `AllOpsResolver`-typed
/// slot. That presumably relies on the two C++ types having identical size
/// and layout, so confirm this whenever the TFLM version changes.
type OpResolverT = tflite::ops::micro::AllOpsResolver;
/// Resolver wrappers that can hand over their underlying binding-level
/// resolver object.
pub trait OpResolverRepr {
    /// Consumes the wrapper and returns the owned inner resolver.
    fn to_inner(self) -> OpResolverT;
}
#[derive(Default)]
pub struct AllOpResolver(OpResolverT);
impl OpResolverRepr for AllOpResolver {
    /// Unwraps the newtype and hands back the owned inner resolver.
    fn to_inner(self) -> OpResolverT {
        let Self(resolver) = self;
        resolver
    }
}
#[derive(Default)]
pub struct MutableOpResolver {
pub(crate) inner: OpResolverT,
capacity: usize,
len: usize,
}
impl OpResolverRepr for MutableOpResolver {
    /// Drops the capacity/length bookkeeping and returns the owned inner
    /// resolver.
    fn to_inner(self) -> OpResolverT {
        let Self { inner, .. } = self;
        inner
    }
}
impl fmt::Debug for MutableOpResolver {
    /// Prints only the registered-operator count; the wrapped resolver object
    /// itself is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "MutableOpResolver (ops = {})", self.len)
    }
}
impl AllOpResolver {
    /// Creates a resolver by running the C++
    /// `tflite::ops::micro::AllOpsResolver` default constructor and taking
    /// ownership of the result.
    pub fn new() -> Self {
        // SAFETY: the C++ closure captures nothing (`[]`) and only
        // default-constructs a resolver and returns it by value.
        // NOTE(review): the C++ object is moved bitwise into Rust. This
        // assumes `AllOpsResolver` stores no pointers into itself; confirm.
        let micro_op_resolver = unsafe {
            cpp!([] -> OpResolverT as "tflite::ops::micro::AllOpsResolver" {
                tflite::ops::micro::AllOpsResolver resolver;
                return resolver;
            })
        };
        Self(micro_op_resolver)
    }
}
impl MutableOpResolver {
    /// Number of operator slots in the C++ resolver built by `empty()`.
    ///
    /// Must equal the template argument of `tflite::MicroMutableOpResolver<128>`
    /// in the `cpp!` closure in `empty()`. A C++ template argument cannot refer
    /// to a Rust constant, so the two values have to be kept in sync by hand.
    const MAX_REGISTRATIONS: usize = 128;

    /// Records the registration of one more operator.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` operators have already been registered.
    pub(crate) fn check_then_inc_len(&mut self) {
        assert!(
            self.len < self.capacity,
            "Tensorflow micro does not support more than {} operators.",
            self.capacity
        );
        self.len += 1;
    }

    /// Returns the number of operators registered so far.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no operator has been registered yet.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Creates a resolver with no operators registered and room for 128.
    pub fn empty() -> Self {
        // SAFETY: the C++ closure captures nothing (`[]`) and only
        // default-constructs a resolver and returns it by value.
        // NOTE(review): this builds a `MicroMutableOpResolver<128>` but stores
        // it as `OpResolverT` (`AllOpsResolver`). That relies on the two C++
        // types having identical layout; confirm.
        let micro_op_resolver = unsafe {
            cpp!([] -> OpResolverT as
                "tflite::MicroMutableOpResolver<128>" {
                tflite::MicroMutableOpResolver<128> resolver;
                return resolver;
            })
        };
        Self {
            inner: micro_op_resolver,
            capacity: Self::MAX_REGISTRATIONS,
            len: 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing the all-operators resolver must not panic.
    #[test]
    fn all_ops_resolver() {
        let _resolver = AllOpResolver::new();
    }

    /// Registering a few operators on an empty resolver must not panic.
    #[test]
    fn mutable_op_resolver() {
        let _resolver = MutableOpResolver::empty()
            .depthwise_conv_2d()
            .fully_connected()
            .softmax();
    }
}