//! encoderfile 0.6.2
//!
//! Distribute and run transformer encoders with a single file.
//! Documentation
use crate::{common::model_type, error::ApiError};

use super::{super::tensor::Tensor, Postprocessor, Transform};
use ndarray::{Array3, Ix3};

impl Postprocessor for Transform<model_type::TokenClassification> {
    type Input = Array3<f32>;
    type Output = Array3<f32>;

    /// Runs the optional Lua `Postprocess` function over token-classification
    /// logits of shape `[batch_size, seq_len, num_classes]`.
    ///
    /// Returns the input untouched when no postprocessor is configured.
    /// Fails with [`ApiError::LuaError`] when the Lua call errors, when the
    /// returned tensor is not 3-dimensional, or when the transform changed
    /// the tensor's shape.
    fn postprocess(&self, data: Self::Input) -> Result<Self::Output, ApiError> {
        // No configured postprocessor: identity pass-through.
        let Some(func) = self.postprocessor() else {
            return Ok(data);
        };

        // Capture the incoming shape so shape-changing transforms can be rejected.
        let original_shape = data.shape().to_owned();

        let output = func
            .call::<Tensor>(Tensor(data.into_dyn()))
            .map_err(|e| ApiError::LuaError(e.to_string()))?
            .into_inner()
            .into_dimensionality::<Ix3>()
            .map_err(|e| {
                tracing::error!("Failed to cast array into Ix3: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, seq_len, num_classes]");
                ApiError::LuaError("Error postprocessing token classifications".to_string())
            })?;

        // The transform must be shape-preserving; anything else is an error.
        if output.shape() != original_shape.as_slice() {
            tracing::error!(
                "Transform error: expected tensor of shape {:?}, got tensor of shape {:?}",
                original_shape.as_slice(),
                output.shape()
            );

            return Err(ApiError::LuaError(
                "Error postprocessing token classifications".to_string(),
            ));
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transforms::DEFAULT_LIBS;

    /// An empty script registers no `Postprocess` function, so the input
    /// tensor must come back unchanged.
    #[test]
    fn test_token_cls_no_transform() {
        let transform = Transform::<model_type::TokenClassification>::new(
            DEFAULT_LIBS.to_vec(),
            Some("".to_string()),
        )
        .expect("Failed to create Transform");

        let input = Array3::<f32>::from_elem((32, 16, 2), 2.0);
        let output = transform.postprocess(input.clone()).expect("Failed");

        assert_eq!(input, output);
    }

    /// A Lua identity function must leave the tensor bit-for-bit intact.
    #[test]
    fn test_token_cls_identity_transform() {
        let script = r#"
        function Postprocess(arr)
            return arr
        end
        "#;

        let transform = Transform::<model_type::TokenClassification>::new(
            DEFAULT_LIBS.to_vec(),
            Some(script.to_string()),
        )
        .expect("Failed to create engine");

        let input = Array3::<f32>::from_elem((16, 32, 2), 2.0);
        let output = transform.postprocess(input.clone()).expect("Failed");

        assert_eq!(input, output);
    }

    /// Returning a plain number instead of a tensor must surface as an error.
    #[test]
    fn test_token_cls_transform_bad_fn() {
        let script = r#"
        function Postprocess(arr)
            return 1
        end
        "#;

        let transform = Transform::<model_type::TokenClassification>::new(
            DEFAULT_LIBS.to_vec(),
            Some(script.to_string()),
        )
        .expect("Failed to create engine");

        let input = Array3::<f32>::from_elem((16, 32, 2), 2.0);

        assert!(transform.postprocess(input).is_err())
    }

    /// Collapsing an axis changes dimensionality; the postprocessor must
    /// reject it with a `LuaError` carrying the token-classification message.
    #[test]
    fn test_bad_dimensionality_transform_postprocessing() {
        let script = r#"
        function Postprocess(x)
            return x:sum_axis(1)
        end
        "#;

        let transform = Transform::<model_type::TokenClassification>::new(
            DEFAULT_LIBS.to_vec(),
            Some(script.to_string()),
        )
        .unwrap();

        let input = Array3::<f32>::from_elem((3, 3, 3), 2.0);

        match transform.postprocess(input) {
            Err(ApiError::LuaError(msg)) => {
                assert!(msg.contains("Error postprocessing token classifications"))
            }
            Err(_) => panic!("Didn't return lua error"),
            Ok(_) => panic!("expected postprocess to fail"),
        }
    }
}