tract-kaldi 0.19.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use tract_hir::internal::*;

use nom::IResult;
use nom::{
    branch::alt,
    bytes::complete::*,
    character::complete::*,
    combinator::*,
    number::complete::{le_i32, le_i64},
    sequence::*,
};

use std::collections::HashMap;

use crate::model::{Component, KaldiProtoModel};

use tract_itertools::Itertools;

mod bin;
mod components;
mod config_lines;
mod descriptor;
mod text;

pub fn nnet3(slice: &[u8]) -> TractResult<KaldiProtoModel> {
    let (_, (config, components)) = parse_top_level(slice).map_err(|e| match e {
        nom::Err::Error(err) => format_err!(
            "Parsing kaldi enveloppe at: {:?}",
            err.input.iter().take(120).map(|b| format!("{}", *b as char)).join("")
        ),
        e => format_err!("{:?}", e),
    })?;
    let config_lines = config_lines::parse_config(config)?;
    Ok(KaldiProtoModel { config_lines, components, adjust_final_offset: 0 })
}

pub fn if_then_else<'a, T>(
    condition: bool,
    then: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
    otherwise: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], T> {
    map(pair(cond(condition, then), cond(!condition, otherwise)), |(a, b)| a.or(b).unwrap())
}

/// Parses the outer `<Nnet3> ... </Nnet3>` envelope.
///
/// Layout, in order:
/// 1. an optional two-byte marker `[0, 0x42]`; its presence selects the binary
///    flavour for everything that follows,
/// 2. the `<Nnet3>` opening tag,
/// 3. the configuration text, i.e. everything up to `<NumComponents>` (must be UTF-8),
/// 4. the component count, then that many (name, component) pairs,
/// 5. the `</Nnet3>` closing tag.
///
/// Returns the remaining input with the configuration text and the components
/// keyed by name. A repeated name replaces the earlier entry.
fn parse_top_level(i: &[u8]) -> IResult<&[u8], (&str, HashMap<String, Component>)> {
    let (i, bin) = map(opt(tag([0, 0x42])), |o| Option::is_some(&o))(i)?;
    let (i, _) = open(i, "Nnet3")?;
    let (i, config_text) = map_res(take_until("<NumComponents>"), std::str::from_utf8)(i)?;
    let (i, count) = num_components(bin, i)?;

    let parse_component = component(bin);
    let mut by_name = HashMap::new();
    let mut rest = i;
    for _ in 0..count {
        let (after_name, name) = component_name(rest)?;
        debug!("Parsing component {}", name);
        let (after_component, parsed) = parse_component(after_name)?;
        by_name.insert(name.to_owned(), parsed);
        rest = after_component;
    }

    let (rest, _) = close(rest, "Nnet3")?;
    Ok((rest, (config_text, by_name)))
}

fn num_components(bin: bool, i: &[u8]) -> IResult<&[u8], usize> {
    let (i, _) = open(i, "NumComponents")?;
    let (i, n) = multispaced(integer(bin))(i)?;
    Ok((i, n as usize))
}

/// Builds a parser for one component: `<Klass> attributes... </Klass>`.
///
/// The opening tag gives the component class. Attributes are read by
/// `bin::attributes` when `bin` is true and by `text::attributes` otherwise,
/// and the closing tag must carry the same class name as the opening one.
fn component(bin: bool) -> impl Fn(&[u8]) -> IResult<&[u8], Component> {
    move |input: &[u8]| {
        let (body, klass) = open_any(input)?;
        let (tail, attributes) = if bin {
            bin::attributes(body, klass)?
        } else {
            text::attributes(body)?
        };
        let (rest, _) = close(tail, klass)?;
        let parsed = Component { klass: klass.to_string(), attributes };
        Ok((rest, parsed))
    }
}

fn component_name(i: &[u8]) -> IResult<&[u8], &str> {
    multispaced(delimited(|i| open(i, "ComponentName"), name, multispace0))(i)
}

pub fn open<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
    map(multispaced(tuple((tag("<"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
}

pub fn close<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
    map(multispaced(tuple((tag("</"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
}

pub fn open_any(i: &[u8]) -> IResult<&[u8], &str> {
    multispaced(delimited(tag("<"), name, tag(">")))(i)
}

pub fn name(i: &[u8]) -> IResult<&[u8], &str> {
    map_res(
        recognize(pair(
            alpha1,
            nom::multi::many0(nom::branch::alt((alphanumeric1, tag("."), tag("_"), tag("-")))),
        )),
        std::str::from_utf8,
    )(i)
}

/// Builds a parser for a single integer, in binary or text flavour.
///
/// * Binary (`bin == true`): a size byte followed by a little-endian value,
///   either `4` + `i32` or `8` + `i64`. An 8-byte value that does not fit in an
///   `i32` is a parse error; it is no longer silently truncated.
/// * Text (`bin == false`): an optional `-` followed by decimal digits. Input
///   that does not parse as an `i32` (no digits, overflow) is a parse error.
pub fn integer<'a>(bin: bool) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], i32> {
    if_then_else(
        bin,
        alt((
            preceded(tag([4]), le_i32),
            preceded(
                tag([8]),
                // Checked narrowing: fully-qualified so no extra import is needed.
                map_res(le_i64, <i32 as std::convert::TryFrom<i64>>::try_from),
            ),
        )),
        map_res(
            map_res(
                recognize(pair(opt(tag("-")), take_while(nom::character::is_digit))),
                std::str::from_utf8,
            ),
            |s| s.parse::<i32>(),
        ),
    )
}

pub fn spaced<I, O, E: nom::error::ParseError<I>, F>(it: F) -> impl FnMut(I) -> nom::IResult<I, O, E>
where
    I: nom::InputTakeAtPosition,
    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
    F: FnMut(I) -> nom::IResult<I, O, E>,
{
    delimited(space0, it, space0)
}

pub fn multispaced<I, O, E: nom::error::ParseError<I>, F>(
    it: F,
) -> impl FnMut(I) -> nom::IResult<I, O, E>
where
    I: nom::InputTakeAtPosition,
    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
    F: FnMut(I) -> nom::IResult<I, O, E>,
{
    delimited(multispace0, it, multispace0)
}