Skip to main content

tract_tflite/
model.rs

1use std::collections::hash_map::Entry;
2use std::fmt::Debug;
3
4use flatbuffers::FlatBufferBuilder;
5use tract_core::internal::*;
6
7use crate::registry::Registry;
8use crate::tensors::{flat_tensor_to_tract_fact, flat_tensor_uses_per_axis_q};
9use crate::tflite;
10use crate::tflite::{Buffer, BufferArgs};
11
12pub struct Tflite(Registry);
13
14impl Debug for Tflite {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        write!(f, "tract-TfLite-framework")
17    }
18}
19
20impl Default for Tflite {
21    fn default() -> Self {
22        let mut registry = Registry::default();
23        crate::ops::register_all(&mut registry);
24        Tflite(registry)
25    }
26}
27
28#[derive(Clone, Debug)]
29pub struct TfliteProtoModel(Vec<u8>);
30
31impl TfliteProtoModel {
32    fn new(buf: Vec<u8>) -> TractResult<TfliteProtoModel> {
33        let _ = tflite::root_as_model(&buf)?;
34        Ok(TfliteProtoModel(buf))
35    }
36
37    pub fn root(&self) -> tflite::Model<'_> {
38        unsafe { tflite::root_as_model_unchecked(&self.0) }
39    }
40}
41
42fn write_model<'fb>(
43    registry: &Registry,
44    model: &TypedModel,
45) -> TractResult<FlatBufferBuilder<'fb>> {
46    let mut model = model.clone();
47    crate::rewriter::rewrite_for_tflite(&mut model).context("Pre-dump rewrite")?;
48    let mut builder = flatbuffers::FlatBufferBuilder::new();
49    let mut op_codes = vec![];
50    let sentinel = Buffer::create(&mut builder, &BufferArgs { data: None });
51    let mut buffers = vec![sentinel];
52    crate::ser::ModelBuilder {
53        registry,
54        builder: &mut builder,
55        op_codes: &mut op_codes,
56        buffers: &mut buffers,
57    }
58    .write_model(&model)?;
59    Ok(builder)
60}
61
62impl Tflite {
63    pub fn write(&self, model: &TypedModel, mut w: impl std::io::Write) -> TractResult<()> {
64        let builder = write_model(&self.0, model)?;
65        w.write_all(builder.finished_data())?;
66        Ok(())
67    }
68}
69
70impl Framework<TfliteProtoModel, TypedModel> for Tflite {
71    fn proto_model_for_read(
72        &self,
73        reader: &mut dyn std::io::Read,
74    ) -> tract_core::prelude::TractResult<TfliteProtoModel> {
75        let mut buf = vec![];
76        reader.read_to_end(&mut buf)?;
77        TfliteProtoModel::new(buf)
78    }
79
80    fn model_for_proto_model_with_model_template(
81        &self,
82        proto: &TfliteProtoModel,
83        mut target: TypedModel,
84    ) -> TractResult<TypedModel> {
85        let root = proto.root();
86        let main = &root.subgraphs().context("No subgraphs in Tflite model")?.get(0);
87        let mut mapping = HashMap::new();
88        for input in main.inputs().context("No inputs in Tflite model")? {
89            if !flat_tensor_uses_per_axis_q(main, input) {
90                let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
91                let it = target.add_source(name, fact)?;
92                mapping.insert(input, it);
93            }
94        }
95        for op in main.operators().context("No operators in Tflite model")? {
96            for input in op.inputs().context("No input in Tflite  operator")? {
97                if let Entry::Vacant(slot) = mapping.entry(input) {
98                    let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
99                    let value = fact.konst.with_context(|| format!("Error in TF file for operator {op:?}. No prior computation nor constant for input {input}"))?;
100                    let konst = target.add_const(name, value)?;
101                    slot.insert(konst);
102                }
103            }
104            self.0.deser_op(&root, main, &op, &mut target, &mut mapping).with_context(|| {
105                format!("Translating proto-op from Tflite into tract op: {op:#?}")
106            })?;
107        }
108        let outputs: TVec<_> = main
109            .outputs()
110            .context("No outputs in Tflite model")?
111            .iter()
112            .map(|o| mapping[&o])
113            .collect();
114        target.set_output_outlets(&outputs)?;
115        Ok(target)
116    }
117}