//! encoderfile 0.4.0-rc.1
//!
//! Distribute and run transformer encoders with a single file.
//! See the project documentation for details.
use crate::{
    common::{
        Config, ModelConfig, TokenizerConfig,
        model_type::{self, ModelTypeSpec},
    },
    runtime::{AppState, EncoderfileState},
};
use ort::session::Session;
use parking_lot::Mutex;
use std::str::FromStr;
use std::{fs::File, io::BufReader};

// On-disk fixture directories, one per supported model type. Each is expected
// to contain `config.json`, `tokenizer.json`, and `model.onnx` (see the
// loader helpers below). Paths are relative to the crate's working directory.
const EMBEDDING_DIR: &str = "../models/embedding";
const SEQUENCE_CLASSIFICATION_DIR: &str = "../models/sequence_classification";
const TOKEN_CLASSIFICATION_DIR: &str = "../models/token_classification";

/// Builds an [`AppState`] for the model type `T` by loading the ONNX session,
/// tokenizer, and model config from the fixture directory `dir`.
///
/// # Panics
/// Panics (via the loader helpers) if any of the expected files in `dir`
/// are missing or malformed.
pub fn get_state<T: ModelTypeSpec>(dir: &str) -> AppState<T> {
    // Fixed fixture metadata; only `model_type` varies per instantiation.
    let config = Config {
        name: String::from("my-model"),
        version: String::from("0.0.1"),
        model_type: T::enum_val(),
        transform: None,
    };

    let model_config = get_model_config(dir);
    let tokenizer_service = get_tokenizer(dir);
    let onnx_session = get_model(dir);

    EncoderfileState::new(config, onnx_session, tokenizer_service, model_config).into()
}

/// Convenience wrapper: loads the embedding fixture as an [`AppState`].
pub fn embedding_state() -> AppState<model_type::Embedding> {
    get_state::<model_type::Embedding>(EMBEDDING_DIR)
}

/// Convenience wrapper: loads the embedding fixture, typed as a
/// sentence-embedding [`AppState`] (shares the same on-disk model).
pub fn sentence_embedding_state() -> AppState<model_type::SentenceEmbedding> {
    get_state::<model_type::SentenceEmbedding>(EMBEDDING_DIR)
}

/// Convenience wrapper: loads the sequence-classification fixture as an
/// [`AppState`].
pub fn sequence_classification_state() -> AppState<model_type::SequenceClassification> {
    get_state::<model_type::SequenceClassification>(SEQUENCE_CLASSIFICATION_DIR)
}

/// Convenience wrapper: loads the token-classification fixture as an
/// [`AppState`].
pub fn token_classification_state() -> AppState<model_type::TokenClassification> {
    get_state::<model_type::TokenClassification>(TOKEN_CLASSIFICATION_DIR)
}

/// Reads `{dir}/config.json` and deserializes it into a [`ModelConfig`].
///
/// # Panics
/// Panics if the file cannot be opened or the JSON does not match
/// [`ModelConfig`]. The panic message includes the failing path so fixture
/// setup errors are easy to diagnose.
fn get_model_config(dir: &str) -> ModelConfig {
    // Use Path::join instead of string concatenation for the path.
    let path = std::path::Path::new(dir).join("config.json");
    let file = File::open(&path)
        .unwrap_or_else(|e| panic!("Config not found at {}: {e}", path.display()));
    let reader = BufReader::new(file);

    // Deserialize into struct
    serde_json::from_reader(reader)
        .unwrap_or_else(|e| panic!("Invalid model config at {}: {e}", path.display()))
}

/// Reads `{dir}/tokenizer.json` and builds a [`crate::runtime::TokenizerService`].
///
/// # Panics
/// Panics if the file cannot be read or the tokenizer JSON is invalid; the
/// panic message includes the failing path for easier diagnosis.
fn get_tokenizer(dir: &str) -> crate::runtime::TokenizerService {
    // Use Path::join instead of string concatenation for the path.
    let path = std::path::Path::new(dir).join("tokenizer.json");
    let tokenizer_str = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("Tokenizer json not found at {}: {e}", path.display()));

    get_tokenizer_from_string(tokenizer_str.as_str())
}

fn get_model(dir: &str) -> Mutex<Session> {
    Mutex::new(
        ort::session::Session::builder()
            .expect("Failed to load session")
            .commit_from_file(format!("{}/{}", dir, "model.onnx"))
            .expect("Failed to load model"),
    )
}

/// Parses a serialized `tokenizers` JSON definition into a
/// [`crate::runtime::TokenizerService`] with the default tokenizer config.
///
/// # Panics
/// Panics if the JSON is not a valid tokenizer definition or the service
/// cannot be constructed from it.
fn get_tokenizer_from_string(s: &str) -> crate::runtime::TokenizerService {
    let tokenizer = tokenizers::tokenizer::Tokenizer::from_str(s)
        .unwrap_or_else(|e| panic!("FATAL: Error loading tokenizer: {e:?}"));

    crate::runtime::TokenizerService::new(tokenizer, TokenizerConfig::default())
        .expect("Error loading tokenizer")
}