encoderfile 0.6.2

Distribute and run transformer encoders with a single file.
Documentation
use encoderfile::dev_utils::*;
use encoderfile::inference::{
    embedding::embedding, sequence_classification::sequence_classification,
    token_classification::token_classification,
};
use encoderfile::transforms::{DEFAULT_LIBS, Transform};

#[test]
fn test_embedding_model() {
    let state = embedding_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    let results =
        embedding(session_lock, &transform, encodings.clone()).expect("Failed to compute results");

    assert!(results.len() == encodings.len());
}

#[test]
#[should_panic]
fn test_embedding_inference_with_bad_model() {
    let state = token_classification_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    embedding(session_lock, &transform, encodings.clone()).expect("Failed to compute results");
}

#[test]
fn test_sequence_classification_model() {
    let state = sequence_classification_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    let results = sequence_classification(
        session_lock,
        &transform,
        &state.model_config,
        encodings.clone(),
    )
    .expect("Failed to compute results");

    assert!(results.len() == encodings.len());
}

#[test]
#[should_panic]
fn test_sequence_classification_inference_with_bad_model() {
    let state = embedding_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    sequence_classification(
        session_lock,
        &transform,
        &state.model_config,
        encodings.clone(),
    )
    .expect("Failed to compute results");
}

#[test]
fn test_token_classification_model() {
    let state = token_classification_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    let results = token_classification(
        session_lock,
        &transform,
        &state.model_config,
        encodings.clone(),
    )
    .expect("Failed to compute results");

    assert!(results.len() == encodings.len());
}

#[test]
#[should_panic]
fn test_token_classification_inference_with_bad_model() {
    let state = sequence_classification_state();

    let encodings = state
        .tokenizer
        .encode_text(vec![
            "hello world".to_string(),
            "the quick brown fox jumps over the lazy dog".to_string(),
        ])
        .expect("Failed to encode text");

    let session_lock = state.session.lock();

    let transform =
        Transform::new(DEFAULT_LIBS.to_vec(), None).expect("Failed to create_transform");

    token_classification(
        session_lock,
        &transform,
        &state.model_config,
        encodings.clone(),
    )
    .expect("Failed to compute results");
}