encoderfile 0.4.0-rc.1

Distribute and run transformer encoders with a single file.
Documentation
use mlua::prelude::*;

/// Flattens a (possibly nested) Lua table of numbers into a `Vec<f32>` and records its shape.
///
/// Every value produced by iterating the table's pairs is visited in iteration order; keys are
/// ignored. Integers and floats are cast to `f32`; nested tables are converted recursively and
/// their elements appended to the output.
///
/// `shape` is an out-parameter and should be empty on entry. On success it holds one length per
/// nesting level, outermost first (e.g. `{{1, 2, 3}, {4, 5, 6}}` yields `[2, 3]`, `{}` yields
/// `[0]`). If `shape` is non-empty on entry, only `shape[0]` is overwritten (for nested tables)
/// and nothing is written for flat tables.
///
/// # Errors
///
/// Returns `LuaError::FromLuaConversionError` when:
/// - a value is neither a number nor a table,
/// - numbers and nested tables are mixed at the same level, or
/// - sibling nested tables have different shapes (a ragged tensor).
///
/// Any error raised while iterating the table is propagated unchanged.
pub fn table_to_vec(t: &LuaTable, shape: &mut Vec<usize>) -> Result<Vec<f32>, LuaError> {
    const MIXED_LEVEL: &str = "Tensor table mixes numbers and nested tables at the same level";

    let mut data = Vec::new();
    let mut len = 0;
    let mut saw_scalar = false;
    // Shape of the first nested table at this level; every later sibling must match it,
    // otherwise the reported shape would not describe the flattened data.
    let mut inner_shape: Option<Vec<usize>> = None;

    for pair in t.clone().pairs::<LuaValue, LuaValue>() {
        let (_, v) = pair?;
        len += 1;
        match v {
            LuaValue::Table(subt) => {
                if saw_scalar {
                    return Err(tensor_shape_error(MIXED_LEVEL.to_string()));
                }
                let mut subshape = Vec::new();
                let subdata = table_to_vec(&subt, &mut subshape)?;
                if let Some(expected) = &inner_shape {
                    if *expected != subshape {
                        return Err(tensor_shape_error(format!(
                            "Ragged tensor table: expected sub-table shape {:?}, got {:?}",
                            expected, subshape
                        )));
                    }
                } else {
                    inner_shape = Some(subshape);
                }
                data.extend(subdata);
            }
            LuaValue::Integer(_) | LuaValue::Number(_) if inner_shape.is_some() => {
                return Err(tensor_shape_error(MIXED_LEVEL.to_string()));
            }
            LuaValue::Integer(i) => {
                saw_scalar = true;
                data.push(i as f32);
            }
            LuaValue::Number(n) => {
                saw_scalar = true;
                data.push(n as f32);
            }
            _ => {
                return Err(LuaError::FromLuaConversionError {
                    from: v.type_name(),
                    to: "f32 or nested table".to_string(),
                    message: Some("Invalid value in tensor table".into()),
                });
            }
        }
    }

    if shape.is_empty() {
        // Outermost dimension first, followed by the (validated) dimensions of the children.
        shape.push(len);
        if let Some(inner) = inner_shape {
            shape.extend(inner);
        }
    } else if inner_shape.is_some() {
        shape[0] = len;
    }

    Ok(data)
}

/// Builds the conversion error reported for structurally invalid tensor tables.
fn tensor_shape_error(message: String) -> LuaError {
    LuaError::FromLuaConversionError {
        from: "table",
        to: "f32 or nested table".to_string(),
        message: Some(message),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mlua::Lua;

    /// Evaluates `lua_code` as a Lua table expression and converts it with `table_to_vec`,
    /// returning the flattened data together with the reported shape. Panics on any failure.
    fn run_table_to_vec(lua_code: &str) -> (Vec<f32>, Vec<usize>) {
        // Keep the Lua state bound to a local so it outlives the table handle.
        let lua = Lua::new();
        let mut dims = Vec::new();
        let table = lua.load(lua_code).eval::<LuaTable>().unwrap();
        let values = table_to_vec(&table, &mut dims).unwrap();
        (values, dims)
    }

    #[test]
    fn test_flat_table() {
        let expected = (vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        assert_eq!(run_table_to_vec("{1, 2, 3, 4}"), expected);
    }

    #[test]
    fn test_nested_table_2d() {
        let expected = (vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(run_table_to_vec("{{1, 2, 3}, {4, 5, 6}}"), expected);
    }

    #[test]
    fn test_nested_table_3d() {
        let expected = (
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
        );
        assert_eq!(
            run_table_to_vec("{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}"),
            expected
        );
    }

    #[test]
    fn test_mixed_number_types() {
        // Integers and floats may be interleaved; both end up as f32.
        let expected = (vec![1.0, 2.5, 3.0, 4.5], vec![4]);
        assert_eq!(run_table_to_vec("{1, 2.5, 3, 4.5}"), expected);
    }

    #[test]
    fn test_invalid_value() {
        let lua = Lua::new();
        let table = lua.load("{1, 'bad', 3}").eval::<LuaTable>().unwrap();
        let mut dims = Vec::new();

        // A string element must be rejected with a conversion error.
        match table_to_vec(&table, &mut dims).unwrap_err() {
            LuaError::FromLuaConversionError { message, .. } => {
                assert!(message.unwrap().contains("Invalid value"));
            }
            _ => panic!("Expected FromLuaConversionError"),
        }
    }

    #[test]
    fn test_empty_table() {
        let (values, dims) = run_table_to_vec("{}");
        assert!(values.is_empty());
        assert_eq!(dims, vec![0]);
    }
}