haqumei 0.2.0

Haqumei is a Japanese Grapheme-to-Phoneme (G2P) library implemented in Rust.
Documentation
use ort::{
    session::Session,
    value::{Tensor, TensorRef, Value},
};

use crate::NjdFeature;

/// Serialized encoder model (ONNX), embedded into the binary at compile time.
/// `include_bytes!` resolves the path relative to this source file.
const ENC_MODEL_BYTES: &[u8] = include_bytes!("../yomi_model/nani_enc.onnx");
/// Serialized classifier model (ONNX) that consumes the encoder's output,
/// embedded into the binary at compile time via `include_bytes!`.
const MODEL_BYTES: &[u8] = include_bytes!("../yomi_model/nani_model.onnx");

// NOTE(review): judging by the type and method names, this presumably decides
// whether 何 is read as "nan" rather than "nani"; confirm against the model's
// training data.
/// Two-stage ONNX predictor (encoder followed by classifier) behind
/// [`NaniPredictor::predict_is_nan`].
///
/// Both models are embedded in the binary and loaded by [`NaniPredictor::new`].
pub struct NaniPredictor {
    /// Encoder session: fed a `[1, 6]` string tensor named `"input"`; its first
    /// output is read as `f32`.
    enc_session: Session,
    /// Classifier session: fed the encoder output under `"input"`; its first
    /// output is read as `i64` labels.
    model_session: Session,
}

impl NaniPredictor {
    pub fn new() -> ort::Result<Self> {
        let enc_session = Session::builder()?.commit_from_memory(ENC_MODEL_BYTES)?;

        let model_session = Session::builder()?.commit_from_memory(MODEL_BYTES)?;

        Ok(Self {
            enc_session,
            model_session,
        })
    }

    pub fn predict_is_nan(&mut self, prev_node: Option<&NjdFeature>) -> bool {
        match self.run_inference(prev_node) {
            Ok(prediction) => prediction == 1,
            Err(e) => {
                log::error!("Nani prediction inference failed: {}", e);
                false
            }
        }
    }

    fn run_inference(&mut self, prev_node: Option<&NjdFeature>) -> ort::Result<i64> {
        let njd = match prev_node {
            Some(node) => node,
            None => return Ok(0),
        };

        let features: [String; 6] = [
            njd.pos.to_string(),
            njd.pos_group1.to_string(),
            njd.pos_group2.to_string(),
            njd.pron.to_string(),
            njd.ctype.to_string(),
            njd.cform.to_string(),
        ];

        let shape = [1, 6];
        let tensor = Tensor::from_string_array((shape, features.as_slice()))?;
        let input_value: Value = tensor.into();
        let enc_inputs = ort::inputs!["input" => input_value];

        let enc_outputs = self.enc_session.run(enc_inputs)?;

        let (shape, data) = enc_outputs[0].try_extract_tensor::<f32>()?;

        let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
        let enc_tensor_ref = TensorRef::from_array_view((shape_vec, data))?;

        let model_inputs = ort::inputs!["input" => enc_tensor_ref];

        let model_outputs = self.model_session.run(model_inputs)?;
        let (_, prediction_data) = model_outputs[0].try_extract_tensor::<i64>()?;

        Ok(prediction_data.first().copied().unwrap_or(0))
    }
}