rust-bert 0.23.0

Ready-to-use NLP pipelines and language models
Documentation
use crate::RustBertError;
use tch::{Kind, Scalar};

pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
    Ok(match kind {
        Kind::Uint8 => Scalar::int(u8::MAX.into()),
        Kind::Int8 => Scalar::int(i8::MAX.into()),
        Kind::Int16 => Scalar::int(i16::MAX.into()),
        Kind::Int => Scalar::int(i32::MAX.into()),
        Kind::Int64 => Scalar::int(i64::MAX),
        Kind::Half => Scalar::float(half::f16::INFINITY.into()),
        Kind::Float => Scalar::float(f32::INFINITY.into()),
        Kind::BFloat16 => Scalar::float(half::bf16::INFINITY.into()),
        Kind::Double => Scalar::float(f64::INFINITY),
        _ => {
            return Err(RustBertError::ValueError(format!(
                "Type not supported: attempted to get positive infinity for {kind:?}",
            )))
        }
    })
}

pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
    Ok(match kind {
        Kind::Uint8 => Scalar::int(u8::MIN.into()),
        Kind::Int8 => Scalar::int(i8::MIN.into()),
        Kind::Int16 => Scalar::int(i16::MIN.into()),
        Kind::Int => Scalar::int(i32::MIN.into()),
        Kind::Int64 => Scalar::int(i64::MIN),
        Kind::Half => Scalar::float(half::f16::NEG_INFINITY.into()),
        Kind::Float => Scalar::float(f32::NEG_INFINITY.into()),
        Kind::BFloat16 => Scalar::float(half::bf16::NEG_INFINITY.into()),
        Kind::Double => Scalar::float(f64::NEG_INFINITY),
        _ => {
            return Err(RustBertError::ValueError(format!(
                "Type not supported: attempted to get negative infinity for {kind:?}",
            )))
        }
    })
}

pub(crate) fn get_min(kind: Kind) -> Result<Scalar, RustBertError> {
    Ok(match kind {
        Kind::Uint8 => Scalar::int(u8::MIN.into()),
        Kind::Int8 => Scalar::int(i8::MIN.into()),
        Kind::Int16 => Scalar::int(i16::MIN.into()),
        Kind::Int => Scalar::int(i32::MIN.into()),
        Kind::Int64 => Scalar::int(i64::MIN),
        Kind::Half => Scalar::float(half::f16::MIN.into()),
        Kind::Float => Scalar::float(f32::MIN.into()),
        Kind::BFloat16 => Scalar::float(half::bf16::MIN.into()),
        Kind::Double => Scalar::float(f64::MIN),
        _ => {
            return Err(RustBertError::ValueError(format!(
                "Type not supported: attempted to get min for {kind:?}",
            )))
        }
    })
}