// microflow_macros/src/lib.rs

1//! [![crates.io](https://img.shields.io/crates/v/microflow-macros)](https://crates.io/crates/microflow-macros)
2//! [![docs.rs](https://img.shields.io/docsrs/microflow-macros)](https://docs.rs/microflow-macros)
3//! [![github](https://img.shields.io/github/actions/workflow/status/matteocarnelos/microflow-rs/cargo.yml?branch=main)](https://github.com/matteocarnelos/microflow-rs/actions/workflows/cargo.yml)
4//!
5//! Macro crate of the [MicroFlow](https://github.com/matteocarnelos/microflow-rs) inference engine, namely, the MicroFlow compiler.
6
7extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use proc_macro_error::{abort_call_site, proc_macro_error};
11use std::fs;
12
13use proc_macro2::TokenStream as TokenStream2;
14use quote::{quote, ToTokens};
15use syn::{parse_macro_input, ItemStruct};
16
17use crate::tflite_flatbuffers::tflite::TensorType;
18use ops::*;
19use structmeta::StructMeta;
20use syn::LitStr;
21use tflite_flatbuffers::tflite::{root_as_model, BuiltinOperator};
22
23mod activation;
24mod buffer;
25mod ops;
26mod quantize;
27mod tensor;
28#[path = "../flatbuffers/tflite_generated.rs"]
29#[allow(unused_imports)]
30#[allow(clippy::all)]
31mod tflite_flatbuffers;
32
/// Arguments accepted by the [`model`] attribute macro.
/// `StructMeta` derives the parser for the attribute's argument list,
/// so `#[model("path/to/model.tflite")]` binds the single unnamed
/// string literal to `path`.
#[derive(StructMeta)]
struct Args {
    /// Path to the `.tflite` model file, as written by the user in the attribute.
    #[struct_meta(unnamed)]
    path: LitStr,
}
38
39/// The entry point of MicroFlow.
40/// This attribute-like procedural macro can be placed on `structs` to implement the `predict()`
41/// function based on the given model.
42/// The macro takes as input the path of the model, which must be in the TensorFlow Lite format
43/// (`.tflite`).
44#[proc_macro_error]
45#[proc_macro_attribute]
46pub fn model(args: TokenStream, item: TokenStream) -> TokenStream {
47    let args = parse_macro_input!(args as Args);
48    let item = parse_macro_input!(item as ItemStruct);
49
50    let buf = fs::read(args.path.value()).unwrap_or_else(|_| {
51        abort_call_site!(
52            "couldn't find '{}', please provide a valid path",
53            &args.path.value()
54        )
55    });
56    let model = root_as_model(&buf).unwrap_or_else(|_| {
57        abort_call_site!("invalid model, please provide a valid TensorFlow Lite model")
58    });
59
60    let ident = &item.ident;
61
62    let subgraph = model.subgraphs().unwrap().get(0);
63    let tensors = subgraph.tensors().unwrap();
64    let buffers = model.buffers().unwrap();
65
66    let input = tensors.get(subgraph.inputs().unwrap().get(0) as usize);
67    let mut input_shape: Vec<_> = input.shape().unwrap().iter().map(|e| e as usize).collect();
68    if input_shape.len() == 1 {
69        input_shape.insert(0, 1);
70    }
71    let input_type = match input.type_() {
72        TensorType::INT8 => quote!(i8),
73        TensorType::UINT8 => quote!(u8),
74        _ => unimplemented!(),
75    };
76    let input_tensor = match input_shape.len() {
77        2 => quote!(Tensor2D),
78        4 => quote!(Tensor4D),
79        _ => unimplemented!(),
80    };
81    let input_buffer = match input_shape.len() {
82        2 => quote!(Buffer2D),
83        4 => quote!(Buffer4D),
84        _ => unimplemented!(),
85    };
86    let input_scale: Vec<_> = input
87        .quantization()
88        .unwrap()
89        .scale()
90        .unwrap()
91        .iter()
92        .map(|e| e.to_token_stream())
93        .collect();
94    let input_zero_point: Vec<_> = match input.type_() {
95        TensorType::INT8 => input
96            .quantization()
97            .unwrap()
98            .zero_point()
99            .unwrap()
100            .iter()
101            .map(|e| (e as i8).to_token_stream())
102            .collect(),
103        TensorType::UINT8 => input
104            .quantization()
105            .unwrap()
106            .zero_point()
107            .unwrap()
108            .iter()
109            .map(|e| (e as u8).to_token_stream())
110            .collect(),
111        _ => unimplemented!(),
112    };
113
114    let operators = subgraph.operators().unwrap();
115    let mut layers = TokenStream2::new();
116    for (index, operator) in operators.iter().enumerate() {
117        let layer: Box<dyn ToTokens> = match BuiltinOperator(
118            model
119                .operator_codes()
120                .unwrap()
121                .get(operator.opcode_index() as usize)
122                .deprecated_builtin_code() as i32,
123        ) {
124            BuiltinOperator::FULLY_CONNECTED => {
125                fully_connected::parse(operator, tensors, buffers, index)
126            }
127            BuiltinOperator::DEPTHWISE_CONV_2D => {
128                depthwise_conv_2d::parse(operator, tensors, buffers, index)
129            }
130            BuiltinOperator::CONV_2D => conv_2d::parse(operator, tensors, buffers, index),
131            BuiltinOperator::AVERAGE_POOL_2D => average_pool_2d::parse(operator, tensors),
132            BuiltinOperator::SOFTMAX => softmax::parse(operator, tensors),
133            BuiltinOperator::RESHAPE => Box::new(reshape::parse(operator, tensors)),
134            unsupported_op => abort_call_site!("unsupported operator: {:?}", unsupported_op),
135        };
136        layer.to_tokens(&mut layers)
137    }
138
139    let output = tensors.get(subgraph.outputs().unwrap().get(0) as usize);
140    let mut output_shape: Vec<_> = output.shape().unwrap().iter().map(|e| e as usize).collect();
141    if output_shape.len() == 1 {
142        output_shape.insert(0, 1);
143    }
144    let output_type = match output.type_() {
145        TensorType::INT8 => quote!(i8),
146        TensorType::UINT8 => quote!(u8),
147        _ => unimplemented!(),
148    };
149    let output_tensor = match output_shape.len() {
150        2 => quote!(Tensor2D),
151        4 => quote!(Tensor4D),
152        _ => unimplemented!(),
153    };
154    let output_buffer = match output_shape.len() {
155        2 => quote!(Buffer2D),
156        4 => quote!(Buffer4D),
157        _ => unimplemented!(),
158    };
159
160    let ts = quote! {
161        #item
162        impl #ident {
163            pub fn predict(input: microflow::buffer::#input_buffer<f32, #(#input_shape),*>) -> microflow::buffer::#output_buffer<f32, #(#output_shape),*> {
164                let input = microflow::tensor::#input_tensor::quantize(input, [#(#input_scale),*], [#(#input_zero_point),*]);
165                Self::predict_inner(input).dequantize()
166            }
167
168            pub fn predict_quantized(input: microflow::buffer::#input_buffer<#input_type, #(#input_shape),*>) -> microflow::buffer::#output_buffer<f32, #(#output_shape),*> {
169                let input = microflow::tensor::#input_tensor::new(input, [#(#input_scale),*], [#(#input_zero_point),*]);
170                Self::predict_inner(input).dequantize()
171            }
172
173            fn predict_inner(input: microflow::tensor::#input_tensor<#input_type, #(#input_shape),*, 1usize>) -> microflow::tensor::#output_tensor<#output_type, #(#output_shape),*, 1usize> {
174                #layers
175                input
176            }
177        }
178    };
179
180    fs::write("target/microflow-expansion.rs", ts.to_string()).ok();
181
182    ts.into()
183}