rust-bert 0.23.0

Ready-to-use NLP pipelines and language models
Documentation
use crate::RustBertError;
use ndarray::{ArrayBase, ArrayD, CowArray, CowRepr, IxDyn};

use ort::{Session, Value};
use std::convert::{TryFrom, TryInto};
use tch::{Kind, Tensor};

/// Converts an ONNX Runtime value into a `tch` tensor.
///
/// The value is extracted as `f32` data, copied into an owned array and then
/// handed to `Tensor::try_from`. Only `f32` extraction is attempted here.
///
/// # Arguments
///
/// * `ort_tensor` - ONNX Runtime value to convert.
///
/// # Errors
///
/// Propagates (as `RustBertError`) any failure from the `f32` extraction or
/// from the array-to-tensor conversion.
pub(crate) fn ort_tensor_to_tch(ort_tensor: &Value) -> Result<Tensor, RustBertError> {
    let extracted = ort_tensor.try_extract::<f32>()?;
    // Copy out of the extracted view so the tensor is built from owned data.
    let owned_array = extracted.view().to_owned();
    let tensor = Tensor::try_from(owned_array)?;
    Ok(tensor)
}

/// Builds an ONNX Runtime `Value` from a [`TypedArray`].
///
/// Every variant is forwarded to `Value::from_array` together with the
/// allocator obtained from `session`; the match only exists to unwrap the
/// element-type-specific array.
///
/// # Arguments
///
/// * `session` - session whose allocator is passed to `Value::from_array`.
/// * `array` - typed array to wrap; the returned value shares its lifetime `'a`.
///
/// # Errors
///
/// Propagates (as `RustBertError`) any error returned by `Value::from_array`.
pub(crate) fn array_to_ort<'a>(
    session: &Session,
    array: &'a TypedArray<'a>,
) -> Result<Value<'a>, RustBertError> {
    let allocator = session.allocator();
    let value = match array {
        TypedArray::I64(inner) => Value::from_array(allocator, inner)?,
        TypedArray::F32(inner) => Value::from_array(allocator, inner)?,
        TypedArray::I32(inner) => Value::from_array(allocator, inner)?,
        TypedArray::F64(inner) => Value::from_array(allocator, inner)?,
        TypedArray::F16(inner) => Value::from_array(allocator, inner)?,
        TypedArray::I16(inner) => Value::from_array(allocator, inner)?,
        TypedArray::I8(inner) => Value::from_array(allocator, inner)?,
        TypedArray::UI8(inner) => Value::from_array(allocator, inner)?,
        TypedArray::BF16(inner) => Value::from_array(allocator, inner)?,
    };
    Ok(value)
}

/// Dynamic-dimension (`IxDyn`) ndarray with `CowRepr` storage, tagged by its
/// element type.
///
/// One variant exists per element type handled by the conversion helpers in
/// this module (`tch_tensor_to_ndarray` produces these values and
/// `array_to_ort` consumes them).
pub(crate) enum TypedArray<'a> {
    /// Array of `i64` elements.
    I64(ArrayBase<CowRepr<'a, i64>, IxDyn>),
    /// Array of `f32` elements.
    F32(ArrayBase<CowRepr<'a, f32>, IxDyn>),
    /// Array of `i32` elements.
    I32(ArrayBase<CowRepr<'a, i32>, IxDyn>),
    /// Array of `f64` elements.
    F64(ArrayBase<CowRepr<'a, f64>, IxDyn>),
    /// Array of `half::f16` elements.
    F16(ArrayBase<CowRepr<'a, half::f16>, IxDyn>),
    /// Array of `i16` elements.
    I16(ArrayBase<CowRepr<'a, i16>, IxDyn>),
    /// Array of `i8` elements.
    I8(ArrayBase<CowRepr<'a, i8>, IxDyn>),
    /// Array of `u8` elements.
    UI8(ArrayBase<CowRepr<'a, u8>, IxDyn>),
    /// Array of `half::bf16` elements.
    BF16(ArrayBase<CowRepr<'a, half::bf16>, IxDyn>),
}

pub(crate) fn tch_tensor_to_ndarray(tch_tensor: &Tensor) -> Result<TypedArray, RustBertError> {
    let kind = tch_tensor.kind();
    Ok(match kind {
        Kind::Int64 => {
            let array: ArrayD<i64> = tch_tensor.try_into()?;
            TypedArray::I64(CowArray::from(array))
        }
        Kind::Float => {
            let array: ArrayD<f32> = tch_tensor.try_into()?;
            TypedArray::F32(CowArray::from(array))
        }
        Kind::Int => {
            let array: ArrayD<i32> = tch_tensor.try_into()?;
            TypedArray::I32(CowArray::from(array))
        }
        Kind::Double => {
            let array: ArrayD<f64> = tch_tensor.try_into()?;
            TypedArray::F64(CowArray::from(array))
        }
        Kind::Half => {
            let array: ArrayD<half::f16> = tch_tensor.try_into()?;
            TypedArray::F16(CowArray::from(array))
        }
        Kind::Int16 => {
            let array: ArrayD<i16> = tch_tensor.try_into()?;
            TypedArray::I16(CowArray::from(array))
        }
        Kind::Int8 => {
            let array: ArrayD<i8> = tch_tensor.try_into()?;
            TypedArray::I8(CowArray::from(array))
        }
        Kind::Uint8 => {
            let array: ArrayD<u8> = tch_tensor.try_into()?;
            TypedArray::UI8(CowArray::from(array))
        }
        Kind::BFloat16 => {
            let array: ArrayD<half::bf16> = tch_tensor.try_into()?;
            TypedArray::BF16(CowArray::from(array))
        }
        _ => {
            return Err(RustBertError::ValueError(format!(
                "Type not supported: attempted to get convert torch tensor to ndarray for {kind:?}",
            )))
        }
    })
}