extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro_error::{abort_call_site, proc_macro_error};
use std::fs;

use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, ItemStruct};

use crate::tflite_flatbuffers::tflite::TensorType;
use ops::*;
use structmeta::StructMeta;
use syn::LitStr;
use tflite_flatbuffers::tflite::{root_as_model, BuiltinOperator};

mod activation;
mod buffer;
mod ops;
mod quantize;
mod tensor;
#[path = "../flatbuffers/tflite_generated.rs"]
#[allow(unused_imports)]
#[allow(clippy::all)]
mod tflite_flatbuffers;

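// Arguments to the `model` attribute: the path to the `.tflite` model file,
// given as a single unnamed string literal.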
#[derive(StructMeta)]
struct Args {
    #[struct_meta(unnamed)]
    path: LitStr,
}

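/// Attribute macro that generates a MicroFlow model from a TensorFlow Lite flatbuffer.
///
/// The attribute takes the path to the `.tflite` file as an unnamed string literal,
/// parses the model at expansion time, and adds `predict`, `predict_quantized`, and
/// `predict_inner` methods to the annotated struct.
///
/// A minimal usage sketch; the crate import, struct name, and model path below are
/// illustrative assumptions, not taken from this file:
///
/// ```ignore
/// use microflow_macros::model;
///
/// #[model("models/example.tflite")]
/// struct Network;
///
/// // `input` is an `f32` buffer matching the model's input shape.
/// let output = Network::predict(input);
/// ```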
#[proc_macro_error]
#[proc_macro_attribute]
pub fn model(args: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as Args);
    let item = parse_macro_input!(item as ItemStruct);

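    // Read the flatbuffer from the given path at expansion time, aborting with a
    // readable error if the file cannot be read or is not a valid model.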
    let buf = fs::read(args.path.value()).unwrap_or_else(|_| {
        abort_call_site!(
            "couldn't find '{}', please provide a valid path",
            &args.path.value()
        )
    });
    let model = root_as_model(&buf).unwrap_or_else(|_| {
        abort_call_site!("invalid model, please provide a valid TensorFlow Lite model")
    });

    let ident = &item.ident;

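    // Only the first subgraph of the model is considered.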
    let subgraph = model.subgraphs().unwrap().get(0);
    let tensors = subgraph.tensors().unwrap();
    let buffers = model.buffers().unwrap();

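    // Collect the metadata of the first input tensor: shape (padded to at least
    // two dimensions), element type, and quantization parameters.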
    let input = tensors.get(subgraph.inputs().unwrap().get(0) as usize);
    let mut input_shape: Vec<_> = input.shape().unwrap().iter().map(|e| e as usize).collect();
    if input_shape.len() == 1 {
        input_shape.insert(0, 1);
    }
    let input_type = match input.type_() {
        TensorType::INT8 => quote!(i8),
        TensorType::UINT8 => quote!(u8),
        _ => unimplemented!(),
    };
    let input_tensor = match input_shape.len() {
        2 => quote!(Tensor2D),
        4 => quote!(Tensor4D),
        _ => unimplemented!(),
    };
    let input_buffer = match input_shape.len() {
        2 => quote!(Buffer2D),
        4 => quote!(Buffer4D),
        _ => unimplemented!(),
    };
    let input_scale: Vec<_> = input
        .quantization()
        .unwrap()
        .scale()
        .unwrap()
        .iter()
        .map(|e| e.to_token_stream())
        .collect();
    let input_zero_point: Vec<_> = match input.type_() {
        TensorType::INT8 => input
            .quantization()
            .unwrap()
            .zero_point()
            .unwrap()
            .iter()
            .map(|e| (e as i8).to_token_stream())
            .collect(),
        TensorType::UINT8 => input
            .quantization()
            .unwrap()
            .zero_point()
            .unwrap()
            .iter()
            .map(|e| (e as u8).to_token_stream())
            .collect(),
        _ => unimplemented!(),
    };

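    // Translate each operator of the subgraph, in execution order, into the tokens
    // of the corresponding MicroFlow layer.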
    let operators = subgraph.operators().unwrap();
    let mut layers = TokenStream2::new();
    for (index, operator) in operators.iter().enumerate() {
        let layer: Box<dyn ToTokens> = match BuiltinOperator(
            model
                .operator_codes()
                .unwrap()
                .get(operator.opcode_index() as usize)
                .deprecated_builtin_code() as i32,
        ) {
            BuiltinOperator::FULLY_CONNECTED => {
                fully_connected::parse(operator, tensors, buffers, index)
            }
            BuiltinOperator::DEPTHWISE_CONV_2D => {
                depthwise_conv_2d::parse(operator, tensors, buffers, index)
            }
            BuiltinOperator::CONV_2D => conv_2d::parse(operator, tensors, buffers, index),
            BuiltinOperator::AVERAGE_POOL_2D => average_pool_2d::parse(operator, tensors),
            BuiltinOperator::SOFTMAX => softmax::parse(operator, tensors),
            BuiltinOperator::RESHAPE => Box::new(reshape::parse(operator, tensors)),
            unsupported_op => abort_call_site!("unsupported operator: {:?}", unsupported_op),
        };
        layer.to_tokens(&mut layers)
    }

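    // Collect the metadata of the first output tensor, mirroring the input handling.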
    let output = tensors.get(subgraph.outputs().unwrap().get(0) as usize);
    let mut output_shape: Vec<_> = output.shape().unwrap().iter().map(|e| e as usize).collect();
    if output_shape.len() == 1 {
        output_shape.insert(0, 1);
    }
    let output_type = match output.type_() {
        TensorType::INT8 => quote!(i8),
        TensorType::UINT8 => quote!(u8),
        _ => unimplemented!(),
    };
    let output_tensor = match output_shape.len() {
        2 => quote!(Tensor2D),
        4 => quote!(Tensor4D),
        _ => unimplemented!(),
    };
    let output_buffer = match output_shape.len() {
        2 => quote!(Buffer2D),
        4 => quote!(Buffer4D),
        _ => unimplemented!(),
    };

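    // Generate the `impl` block: `predict` quantizes an `f32` input and dequantizes
    // the output, `predict_quantized` accepts an already quantized input, and
    // `predict_inner` chains the generated layers.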
    let ts = quote! {
        #item
        impl #ident {
            pub fn predict(input: microflow::buffer::#input_buffer<f32, #(#input_shape),*>) -> microflow::buffer::#output_buffer<f32, #(#output_shape),*> {
                let input = microflow::tensor::#input_tensor::quantize(input, [#(#input_scale),*], [#(#input_zero_point),*]);
                Self::predict_inner(input).dequantize()
            }

            pub fn predict_quantized(input: microflow::buffer::#input_buffer<#input_type, #(#input_shape),*>) -> microflow::buffer::#output_buffer<f32, #(#output_shape),*> {
                let input = microflow::tensor::#input_tensor::new(input, [#(#input_scale),*], [#(#input_zero_point),*]);
                Self::predict_inner(input).dequantize()
            }

            fn predict_inner(input: microflow::tensor::#input_tensor<#input_type, #(#input_shape),*, 1usize>) -> microflow::tensor::#output_tensor<#output_type, #(#output_shape),*, 1usize> {
                #layers
                input
            }
        }
    };

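    // Dump the generated code for inspection; write failures are ignored.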
    fs::write("target/microflow-expansion.rs", ts.to_string()).ok();

    ts.into()
}