rust-bert 0.23.0

Ready-to-use NLP pipelines and language models
Documentation
use crate::pipelines::generation_utils::{Cache, LMModelOutput};
use crate::pipelines::onnx::common::{get_input_output_mapping, InputOutputNameMapping};
use crate::pipelines::onnx::config::{
    ONNXEnvironmentConfig, ATTENTION_MASK_NAME, ENCODER_ATTENTION_MASK_NAME,
    ENCODER_HIDDEN_STATES_NAME, INPUT_IDS_NAME, POSITION_IDS,
};
use crate::pipelines::onnx::conversion::{array_to_ort, ort_tensor_to_tch, tch_tensor_to_ndarray};
use crate::pipelines::onnx::models::ONNXLayerCache;
use crate::RustBertError;
use ort::{Environment, Session};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tch::Tensor;

pub struct ONNXDecoder {
    session: Session,
    name_mapping: InputOutputNameMapping,
    use_cache: bool,
}

impl ONNXDecoder {
    pub fn new(
        model_file: PathBuf,
        use_cache: bool,
        environment: &Arc<Environment>,
        onnx_config: &ONNXEnvironmentConfig,
    ) -> Result<Self, RustBertError> {
        let session = onnx_config
            .get_session_builder(environment)?
            .with_model_from_file(model_file)?;
        let name_mapping = get_input_output_mapping(&session);
        Ok(Self {
            session,
            name_mapping,
            use_cache,
        })
    }

    pub fn forward(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_attention_mask: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        layer_states: Option<&ONNXLayerCache>,
    ) -> Result<LMModelOutput, RustBertError> {
        let mut input_dict = HashMap::new();
        if let Some(input_ids) = input_ids {
            input_dict.insert(INPUT_IDS_NAME, input_ids);
        }
        if let Some(attention_mask) = attention_mask {
            input_dict.insert(ATTENTION_MASK_NAME, attention_mask);
        }
        if let Some(encoder_hidden_states) = encoder_hidden_states {
            input_dict.insert(ENCODER_HIDDEN_STATES_NAME, encoder_hidden_states);
        }
        if let Some(encoder_attention_mask) = encoder_attention_mask {
            input_dict.insert(ENCODER_ATTENTION_MASK_NAME, encoder_attention_mask);
        }
        if let Some(position_ids) = position_ids {
            input_dict.insert(POSITION_IDS, position_ids);
        }

        let inputs_arrays = self
            .name_mapping
            .input_names
            .iter()
            .map(|input_name| {
                if let Some(tensor) = input_dict.remove(input_name.as_str()) {
                    tch_tensor_to_ndarray(tensor)
                } else {
                    let layer_states = layer_states.ok_or_else(|| {
                        RustBertError::OrtError(format!(
                            "{input_name} not found and cache was not provided."
                        ))
                    })?;
                    let input_pos = layer_states
                        .values
                        .get(&input_name.replace("past", "present"))
                        .or_else(|| {
                            layer_states
                                .values
                                .get(&input_name.replace("past_key_values", "present"))
                        })
                        .ok_or_else(|| {
                            let found_keys = layer_states.values.keys().collect::<Vec<&String>>();
                            RustBertError::OrtError(format!(
                                "{input_name} not found in cache ({found_keys:?})."
                            ))
                        })?;
                    tch_tensor_to_ndarray(input_pos)
                }
            })
            .collect::<Result<Vec<_>, RustBertError>>()?;

        let input_values = inputs_arrays
            .iter()
            .map(|array| array_to_ort(&self.session, array).unwrap())
            .collect::<Vec<_>>();

        let outputs = self.session.run(input_values)?;

        let lm_logits =
            ort_tensor_to_tch(&outputs[*self.name_mapping.output_names.get("logits").unwrap()])?;
        let cache = if self.use_cache {
            Cache::ONNXCache(ONNXLayerCache::from_ort_output(
                &outputs,
                &self.name_mapping.key_value_output_names,
            )?)
        } else {
            Cache::None
        };

        Ok(LMModelOutput { lm_logits, cache })
    }
}