// encoderfile 0.1.0-rc.17
//
// Distribute and run transformer encoders with a single file.
// Documentation
use anyhow::{Result, bail};
use ort::{
    session::{Output, Session},
    tensor::Shape,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::Path;

// The kinds of model this module knows how to validate.
//
// Variant names are (de)serialized in `snake_case` (see `rename_all` below).
//
// NOTE: plain `//` comments are used here on purpose; `///` doc comments on
// this type would be picked up by the `JsonSchema` derive as schema
// descriptions and change the generated schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ModelType {
    // Validated via `validate_embedding_model` (expects a `last_hidden_state` output).
    Embedding,
    // Validated via `validate_sequence_classification_model` (expects a `logits` output).
    SequenceClassification,
    // Validated via `validate_token_classification_model` (expects a `logits` output).
    TokenClassification,
    // Validated with the same check as `Embedding`.
    SentenceEmbedding,
}

impl ModelType {
    /// Loads the model stored at `path` and checks that its declared outputs
    /// satisfy the requirements of this model type.
    ///
    /// `Embedding` and `SentenceEmbedding` share the same validation routine.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded from `path`, or if the
    /// type-specific validator rejects the model's outputs.
    pub fn validate_model(&self, path: &Path) -> Result<()> {
        let session = load_model(path)?;

        match self {
            Self::Embedding | Self::SentenceEmbedding => validate_embedding_model(session),
            Self::TokenClassification => validate_token_classification_model(session),
            Self::SequenceClassification => validate_sequence_classification_model(session),
        }
    }
}

/// Checks that `model` declares a `last_hidden_state` output whose tensor
/// shape has exactly three dimensions.
///
/// Only the number of dimensions is inspected, not their sizes.
///
/// # Errors
///
/// Fails if the output is missing, is not a tensor, or does not have three
/// dimensions.
fn validate_embedding_model(model: Session) -> Result<()> {
    let rank = get_outp_dim(model.outputs.as_slice(), "last_hidden_state")?.len();

    match rank {
        3 => Ok(()),
        _ => bail!("Model must return tensor of shape [batch_size, seq_len, hidden_dim]"),
    }
}

/// Checks that `model` declares a `logits` output whose tensor shape has
/// exactly two dimensions.
///
/// Only the number of dimensions is inspected, not their sizes.
///
/// # Errors
///
/// Fails if the output is missing, is not a tensor, or does not have two
/// dimensions.
fn validate_sequence_classification_model(model: Session) -> Result<()> {
    let logits_shape = get_outp_dim(model.outputs.as_slice(), "logits")?;

    // Happy path first: a rank-2 `logits` tensor is all that is required.
    if logits_shape.len() == 2 {
        return Ok(());
    }

    bail!("Model must return tensor of shape [batch_size, n_labels]")
}

/// Checks that `model` declares a `logits` output whose tensor shape has
/// exactly three dimensions.
///
/// Only the number of dimensions is inspected, not their sizes.
///
/// # Errors
///
/// Fails if the output is missing, is not a tensor, or does not have three
/// dimensions.
fn validate_token_classification_model(model: Session) -> Result<()> {
    // Number of dimensions the `logits` tensor must have.
    const EXPECTED_RANK: usize = 3;

    let outputs = model.outputs.as_slice();
    let n_dims = get_outp_dim(outputs, "logits")?.len();

    if n_dims == EXPECTED_RANK {
        Ok(())
    } else {
        bail!("Model must return tensor of shape [batch_size, seq_len, n_labels]")
    }
}

/// Finds the output named `outp_name` among `outputs` and returns its tensor
/// shape.
///
/// The returned shape borrows from `outputs`.
///
/// # Errors
///
/// - `"Model must return <outp_name>"` if no output has that name.
/// - `"Model must return tensor"` if `tensor_shape()` yields `None` for the
///   matching output.
fn get_outp_dim<'a>(outputs: &'a [Output], outp_name: &str) -> Result<&'a Shape> {
    outputs
        .iter()
        .find(|output| output.name == outp_name)
        // `ok_or_else` defers building the error (and formatting its message)
        // to the failure path; `ok_or` would construct it on every call.
        .ok_or_else(|| anyhow::anyhow!("Model must return {}", outp_name))?
        .output_type
        .tensor_shape()
        .ok_or_else(|| anyhow::anyhow!("Model must return tensor"))
}

/// Builds a `Session` from the model file at `file`.
///
/// # Errors
///
/// Propagates any error from creating the session builder or from
/// committing the model file.
fn load_model(file: &Path) -> Result<Session> {
    let session = Session::builder()?.commit_from_file(file)?;

    Ok(session)
}