encoderfile 0.4.0-rc.1

Distribute and run transformer encoders with a single file.
Documentation
use crate::{
    common::{ModelConfig, SequenceClassificationResult},
    error::ApiError,
    transforms::{Postprocessor, SequenceClassificationTransform},
};
use ndarray::{Array2, Axis, Ix2};
use ndarray_stats::QuantileExt;
use tokenizers::Encoding;

/// Runs a batch of tokenized inputs through a sequence-classification model.
///
/// The steps are:
/// 1. build the model inputs from `encodings` via `prepare_inputs!`,
/// 2. run the model via `run_model!`, propagating any failure with `?`,
/// 3. copy the `"logits"` output into an owned two-dimensional `f32` array,
/// 4. pass that array through `transform.postprocess`,
/// 5. convert each row into a [`SequenceClassificationResult`] with
///    [`postprocess`].
///
/// # Errors
///
/// Returns an [`ApiError`] when running the model or applying `transform`
/// fails.
///
/// # Panics
///
/// Panics if the model output has no `"logits"` entry, if that entry cannot
/// be extracted as `f32`, or if it is not two-dimensional.
#[tracing::instrument(skip_all)]
pub fn sequence_classification<'a>(
    mut session: crate::runtime::Model<'a>,
    transform: &SequenceClassificationTransform,
    config: &ModelConfig,
    encodings: Vec<Encoding>,
) -> Result<Vec<SequenceClassificationResult>, ApiError> {
    // NOTE(review): presumably input ids, attention mask and token type ids —
    // confirm against the `prepare_inputs!` macro definition.
    let (a_ids, a_mask, a_type_ids) = crate::prepare_inputs!(encodings);

    // Pull the `logits` output and copy it into an owned 2-D f32 array.
    let raw_logits = crate::run_model!(session, a_ids, a_mask, a_type_ids)?
        .get("logits")
        .expect("Model does not return logits")
        .try_extract_array::<f32>()
        .expect("Model does not return tensor extractable to f32")
        .into_dimensionality::<Ix2>()
        .expect("Model does not return tensor of shape [n_batch, n_labels]")
        .into_owned();

    // Apply the configured transform before the per-row results are built.
    let transformed = transform.postprocess(raw_logits)?;

    Ok(postprocess(transformed, config))
}

/// Converts each row of `outputs` into a [`SequenceClassificationResult`].
///
/// Rows are taken along axis 0. For every row:
/// - `logits` and `scores` both receive a copy of the row's values (this
///   function is given a single array, so the two fields are identical),
/// - `predicted_index` is the position of the row's maximum value,
/// - `predicted_label` is the result of `config.id2label` for that index,
///   mapped through `to_string`.
///
/// # Panics
///
/// Panics with `"Model has 0 labels"` if `argmax` fails on a row.
#[tracing::instrument(skip_all)]
pub fn postprocess(
    outputs: Array2<f32>,
    config: &ModelConfig,
) -> Vec<SequenceClassificationResult> {
    let mut results = Vec::with_capacity(outputs.nrows());

    for row in outputs.axis_iter(Axis(0)) {
        // Determine the winning label first; a failing argmax aborts before
        // any per-row buffers are copied.
        let best = row.argmax().expect("Model has 0 labels");
        let label_id = best as u32;

        let logits = row.to_owned().into_raw_vec_and_offset().0;
        let scores = logits.clone();

        results.push(SequenceClassificationResult {
            logits,
            scores,
            predicted_index: label_id,
            predicted_label: config.id2label(label_id).map(|label| label.to_string()),
        });
    }

    results
}