tract_kaldi/
model.rs

1use std::collections::BTreeMap;
2
3use tract_hir::internal::*;
4
5#[derive(Clone, Debug)]
6pub struct KaldiProtoModel {
7    pub config_lines: ConfigLines,
8    pub components: HashMap<String, Component>,
9    pub adjust_final_offset: isize,
10}
11
12#[derive(Clone, Debug)]
13pub struct ConfigLines {
14    pub input_name: String,
15    pub input_dim: usize,
16    pub nodes: Vec<(String, NodeLine)>,
17    pub outputs: Vec<OutputLine>,
18}
19
20#[derive(Clone, Debug)]
21pub enum NodeLine {
22    Component(ComponentNode),
23    DimRange(DimRangeNode),
24}
25
26#[derive(Clone, Debug)]
27pub struct OutputLine {
28    pub output_alias: String,
29    pub descriptor: GeneralDescriptor,
30}
31
/// Expression describing where a node or an output takes its input from.
#[derive(Clone, Debug, PartialEq)]
pub enum GeneralDescriptor {
    /// Several descriptors combined together, in order.
    Append(Vec<GeneralDescriptor>),
    /// Wrapper around another descriptor (Kaldi `IfDefined`).
    IfDefined(Box<GeneralDescriptor>),
    /// Reference to a node by its name.
    Name(String),
    /// Another descriptor, shifted by a signed offset.
    Offset(Box<GeneralDescriptor>, isize),
}
39
40impl GeneralDescriptor {
41    pub fn inputs(&self) -> TVec<&str> {
42        match self {
43            GeneralDescriptor::Append(ref gds) => gds.iter().fold(tvec!(), |mut acc, gd| {
44                gd.inputs().iter().for_each(|i| {
45                    if !acc.contains(i) {
46                        acc.push(i)
47                    }
48                });
49                acc
50            }),
51            GeneralDescriptor::IfDefined(ref gd) => gd.inputs(),
52            GeneralDescriptor::Name(ref s) => tvec!(&**s),
53            GeneralDescriptor::Offset(ref gd, _) => gd.inputs(),
54        }
55    }
56
57    pub fn as_conv_shape_dilation(&self) -> Option<(usize, usize)> {
58        if let GeneralDescriptor::Name(_) = self {
59            return Some((1, 1));
60        }
61        if let GeneralDescriptor::Append(ref appendees) = self {
62            let mut offsets = vec![];
63            for app in appendees {
64                match app {
65                    GeneralDescriptor::Name(_) => offsets.push(0),
66                    GeneralDescriptor::Offset(_, offset) => offsets.push(*offset),
67                    _ => return None,
68                }
69            }
70            let dilation = offsets[1] - offsets[0];
71            if offsets.windows(2).all(|pair| pair[1] - pair[0] == dilation) {
72                return Some((offsets.len(), dilation as usize));
73            }
74        }
75        None
76    }
77
78    fn wire(
79        &self,
80        inlet: InletId,
81        name: &str,
82        model: &mut InferenceModel,
83        deferred: &mut BTreeMap<InletId, String>,
84        adjust_final_offset: Option<isize>,
85    ) -> TractResult<()> {
86        use GeneralDescriptor::*;
87        match self {
88            Name(n) => {
89                deferred.insert(inlet, n.to_string());
90                return Ok(());
91            }
92            Append(appendees) => {
93                let name = format!("{name}.Append");
94                let id = model.add_node(
95                    &*name,
96                    expand(tract_hir::ops::array::Concat::new(1)),
97                    tvec!(InferenceFact::default()),
98                )?;
99                model.add_edge(OutletId::new(id, 0), inlet)?;
100                for (ix, appendee) in appendees.iter().enumerate() {
101                    let name = format!("{name}-{ix}");
102                    appendee.wire(
103                        InletId::new(id, ix),
104                        &name,
105                        model,
106                        deferred,
107                        adjust_final_offset,
108                    )?;
109                }
110                return Ok(());
111            }
112            IfDefined(ref o) => {
113                if let Offset(n, o) = &**o {
114                    if let Name(n) = &**n {
115                        let name = format!("{name}.memory");
116                        model.add_node(
117                            &*name,
118                            crate::ops::memory::Memory::new(n.to_string(), *o),
119                            tvec!(InferenceFact::default()),
120                        )?;
121                        deferred.insert(inlet, name);
122                        return Ok(());
123                    }
124                }
125            }
126            Offset(ref n, o) if *o > 0 => {
127                let name = format!("{name}-Delay");
128                let crop = *o + adjust_final_offset.unwrap_or(0);
129                if crop < 0 {
130                    bail!("Invalid offset adjustment (network as {}, adjustment is {}", o, crop)
131                }
132                let id = model.add_node(
133                    &*name,
134                    expand(tract_hir::ops::array::Crop::new(0, crop as usize, 0)),
135                    tvec!(InferenceFact::default()),
136                )?;
137                model.add_edge(OutletId::new(id, 0), inlet)?;
138                n.wire(InletId::new(id, 0), &name, model, deferred, adjust_final_offset)?;
139                return Ok(());
140            }
141            _ => (),
142        }
143        bail!("Unhandled input descriptor: {:?}", self)
144    }
145}
146
147#[derive(Clone, Debug)]
148pub struct DimRangeNode {
149    pub input: GeneralDescriptor,
150    pub offset: usize,
151    pub dim: usize,
152}
153
154#[derive(Clone, Debug)]
155pub struct ComponentNode {
156    pub input: GeneralDescriptor,
157    pub component: String,
158}
159
160#[derive(Clone, Debug, Default)]
161pub struct Component {
162    pub klass: String,
163    pub attributes: HashMap<String, Arc<Tensor>>,
164}
165
166pub struct ParsingContext<'a> {
167    pub proto_model: &'a KaldiProtoModel,
168}
169
170type OpBuilder = fn(&ParsingContext, node: &str) -> TractResult<Box<dyn InferenceOp>>;
171
172#[derive(Clone, Default)]
173pub struct KaldiOpRegister(pub HashMap<String, OpBuilder>);
174
175impl KaldiOpRegister {
176    pub fn insert(&mut self, s: &'static str, builder: OpBuilder) {
177        self.0.insert(s.into(), builder);
178    }
179}
180
181#[derive(Clone, Default)]
182pub struct Kaldi {
183    pub op_register: KaldiOpRegister,
184}
185
186impl Framework<KaldiProtoModel, InferenceModel> for Kaldi {
187    fn proto_model_for_read(&self, r: &mut dyn std::io::Read) -> TractResult<KaldiProtoModel> {
188        use crate::parser;
189        let mut v = vec![];
190        r.read_to_end(&mut v)?;
191        parser::nnet3(&v)
192    }
193
194    fn model_for_proto_model_with_symbols(
195        &self,
196        proto_model: &KaldiProtoModel,
197        symbols: &SymbolTable,
198    ) -> TractResult<InferenceModel> {
199        let ctx = ParsingContext { proto_model };
200        let mut model =
201            InferenceModel { symbol_table: symbols.to_owned(), ..InferenceModel::default() };
202
203        let s = model.symbol_table.sym("S");
204        model.add_source(
205            proto_model.config_lines.input_name.clone(),
206            f32::fact(dims!(s, proto_model.config_lines.input_dim)).into(),
207        )?;
208        let mut inputs_to_wire: BTreeMap<InletId, String> = Default::default();
209        for (name, node) in &proto_model.config_lines.nodes {
210            match node {
211                NodeLine::Component(line) => {
212                    let component = &proto_model.components[&line.component];
213                    if crate::ops::AFFINE.contains(&&*component.klass)
214                        && line.input.as_conv_shape_dilation().is_some()
215                    {
216                        let op = crate::ops::affine::affine_component(&ctx, name)?;
217                        let id = model.add_node(
218                            name.to_string(),
219                            op,
220                            tvec!(InferenceFact::default()),
221                        )?;
222                        inputs_to_wire
223                            .insert(InletId::new(id, 0), line.input.inputs()[0].to_owned());
224                    } else {
225                        let op = match self.op_register.0.get(&*component.klass) {
226                            Some(builder) => (builder)(&ctx, name)?,
227                            None => Box::new(tract_hir::ops::unimpl::UnimplementedOp::new(
228                                1,
229                                &component.klass,
230                                format!("{line:?}"),
231                            )),
232                        };
233                        let id = model.add_node(
234                            name.to_string(),
235                            op,
236                            tvec!(InferenceFact::default()),
237                        )?;
238                        line.input.wire(
239                            InletId::new(id, 0),
240                            name,
241                            &mut model,
242                            &mut inputs_to_wire,
243                            None,
244                        )?
245                    }
246                }
247                NodeLine::DimRange(line) => {
248                    let op =
249                        tract_hir::ops::array::Slice::new(1, line.offset, line.offset + line.dim);
250                    let id =
251                        model.add_node(name.to_string(), op, tvec!(InferenceFact::default()))?;
252                    line.input.wire(
253                        InletId::new(id, 0),
254                        name,
255                        &mut model,
256                        &mut inputs_to_wire,
257                        None,
258                    )?
259                }
260            }
261        }
262        let mut outputs = vec![];
263        for o in &proto_model.config_lines.outputs {
264            let output = model.add_node(
265                &*o.output_alias,
266                tract_hir::ops::identity::Identity::default(),
267                tvec!(InferenceFact::default()),
268            )?;
269            model.set_outlet_label(output.into(), o.output_alias.to_string())?;
270            o.descriptor.wire(
271                InletId::new(output, 0),
272                "output",
273                &mut model,
274                &mut inputs_to_wire,
275                Some(proto_model.adjust_final_offset),
276            )?;
277            outputs.push(OutletId::new(output, 0));
278        }
279        for (inlet, name) in inputs_to_wire {
280            let src = OutletId::new(model.node_by_name(&*name)?.id, 0);
281            model.add_edge(src, inlet)?;
282        }
283        model.set_output_outlets(&outputs)?;
284        Ok(model)
285    }
286}