rust-bert 0.23.0

Ready-to-use NLP pipelines and language models
Documentation
use crate::RustBertError;
use tch::nn::Embedding;
use tch::{Device, Tensor};

/// Resolves the mutually exclusive `input_ids` / `input_embeddings` pair.
///
/// Exactly one of the two optional inputs must be provided.
///
/// # Arguments
///
/// * `input_ids` - Optional tensor that, when set, is passed through `embeddings_matrix`.
/// * `input_embeddings` - Optional tensor supplied in place of `input_ids`.
/// * `embeddings_matrix` - Embedding layer applied to `input_ids`.
///
/// # Returns
///
/// A tuple made of:
/// * `Option<Tensor>` - the result of applying `embeddings_matrix` to `input_ids`, or `None`
///   when `input_embeddings` was supplied (no new tensor is produced in that case).
/// * `Vec<i64>` - the full size of `input_ids`, or the first two dimensions of
///   `input_embeddings`.
/// * `Device` - the device of whichever input was provided.
///
/// # Errors
///
/// Returns `RustBertError::ValueError` when both inputs are set, when neither is set, or when
/// `input_embeddings` has fewer than two dimensions.
pub fn process_ids_embeddings_pair(
    input_ids: Option<&Tensor>,
    input_embeddings: Option<&Tensor>,
    embeddings_matrix: &Embedding,
) -> Result<(Option<Tensor>, Vec<i64>, Device), RustBertError> {
    match (input_ids, input_embeddings) {
        (Some(_), Some(_)) => Err(RustBertError::ValueError(
            "Only one of input ids or input embeddings may be set".into(),
        )),
        (Some(input_value), None) => Ok((
            Some(input_value.apply(embeddings_matrix)),
            input_value.size(),
            input_value.device(),
        )),
        (None, Some(embeds)) => {
            // Query the size once and validate the rank so that a malformed tensor is
            // reported as an error instead of panicking on an out-of-bounds index.
            let dims = embeds.size();
            let size = dims.get(..2).map(|leading| leading.to_vec()).ok_or_else(|| {
                RustBertError::ValueError(
                    "Input embeddings must have at least 2 dimensions".into(),
                )
            })?;
            Ok((None, size, embeds.device()))
        }
        (None, None) => Err(RustBertError::ValueError(
            "At least one of input ids or input embeddings must be set".into(),
        )),
    }
}

/// Extracts the shape and device from the mutually exclusive `input_ids` /
/// `input_embeddings` pair without computing any embeddings.
///
/// Exactly one of the two optional inputs must be provided.
///
/// # Arguments
///
/// * `input_ids` - Optional tensor; when set, its full size is returned.
/// * `input_embeddings` - Optional tensor; when set, its first two dimensions are returned.
///
/// # Returns
///
/// A tuple made of:
/// * `Vec<i64>` - the full size of `input_ids`, or the first two dimensions of
///   `input_embeddings`.
/// * `Device` - the device of whichever input was provided.
///
/// # Errors
///
/// Returns `RustBertError::ValueError` when both inputs are set, when neither is set, or when
/// `input_embeddings` has fewer than two dimensions.
pub fn get_shape_and_device_from_ids_embeddings_pair(
    input_ids: Option<&Tensor>,
    input_embeddings: Option<&Tensor>,
) -> Result<(Vec<i64>, Device), RustBertError> {
    match (input_ids, input_embeddings) {
        (Some(_), Some(_)) => Err(RustBertError::ValueError(
            "Only one of input ids or input embeddings may be set".into(),
        )),
        (Some(input_value), None) => Ok((input_value.size(), input_value.device())),
        (None, Some(embeds)) => {
            // Query the size once and validate the rank so that a malformed tensor is
            // reported as an error instead of panicking on an out-of-bounds index.
            let dims = embeds.size();
            let size = dims.get(..2).map(|leading| leading.to_vec()).ok_or_else(|| {
                RustBertError::ValueError(
                    "Input embeddings must have at least 2 dimensions".into(),
                )
            })?;
            Ok((size, embeds.device()))
        }
        (None, None) => Err(RustBertError::ValueError(
            "At least one of input ids or input embeddings must be set".into(),
        )),
    }
}