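//! Tests for ndarrow 0.0.3, a zero-copy bridge between Apache Arrow and ndarray.
//!
//! They check that malformed tensor extension fields (missing or unparsable metadata,
//! wrong storage layouts) are rejected with descriptive errors.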
use std::{collections::HashMap, sync::Arc};

use arrow_array::{
    Array, ArrayRef, FixedSizeListArray, Float32Array, Int32Array, ListArray, StructArray,
};
use arrow_buffer::{OffsetBuffer, ScalarBuffer};
use arrow_schema::{
    DataType, Field,
    extension::{
        EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY, ExtensionType, FixedShapeTensor,
        VariableShapeTensor,
    },
};
use ndarray::{ArrayD, IxDyn};
use ndarrow::{
    arrayd_to_fixed_shape_tensor, arrays_to_variable_shape_tensor, deserialize_registered_extension,
};

/// Builds a non-nullable field named `field` that carries Arrow extension-type metadata.
///
/// The extension name is always stored under `EXTENSION_TYPE_NAME_KEY`. The serialized
/// extension metadata is stored under `EXTENSION_TYPE_METADATA_KEY` only when
/// `metadata_json` is `Some`, which lets tests build fields whose metadata is absent.
fn field_with_extension(
    data_type: DataType,
    extension_name: &str,
    metadata_json: Option<&str>,
) -> Field {
    let name_entry = (EXTENSION_TYPE_NAME_KEY.to_owned(), extension_name.to_owned());
    // `Option` is iterable, so a `None` here simply contributes no entry.
    let metadata_entry =
        metadata_json.map(|json| (EXTENSION_TYPE_METADATA_KEY.to_owned(), json.to_owned()));
    let metadata: HashMap<String, String> =
        std::iter::once(name_entry).chain(metadata_entry).collect();
    Field::new("field", data_type, false).with_metadata(metadata)
}

/// Returns a `List<Float32>` data type whose non-nullable child field is named `item`.
fn simple_list_f32_type() -> DataType {
    let item_field = Field::new("item", DataType::Float32, false);
    DataType::List(Arc::new(item_field))
}

/// Returns a `FixedSizeList<Int32>` data type of length `dimensions` with a non-nullable
/// `item` child. These tests use it as the type of a variable-shape tensor's `shape` column.
fn simple_shape_type(dimensions: i32) -> DataType {
    let item_field = Arc::new(Field::new("item", DataType::Int32, false));
    DataType::FixedSizeList(item_field, dimensions)
}

/// A `FixedShapeTensor` field with no extension-metadata entry must be rejected.
///
/// The storage type comes from `arrayd_to_fixed_shape_tensor`, so it is the missing
/// metadata, not the storage layout, that the error is expected to report.
#[test]
fn fixed_shape_extension_rejects_missing_metadata() {
    let tensor =
        ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f64, 2.0, 3.0, 4.0]).expect("shape valid");
    let (_field, storage) =
        arrayd_to_fixed_shape_tensor("tensor", tensor).expect("fixed-shape tensor should build");
    // Reuse the real storage type but attach only the extension name (no metadata JSON).
    let field = field_with_extension(storage.data_type().clone(), FixedShapeTensor::NAME, None);

    let err = deserialize_registered_extension(&field).expect_err("missing metadata must fail");
    let message = err.to_string();
    assert!(
        message.contains("metadata missing"),
        "unexpected error message: {message}"
    );
}

/// A `FixedShapeTensor` field whose storage is a plain `Float64` column, rather than a
/// `FixedSizeList`, must be rejected even though its shape metadata parses.
#[test]
fn fixed_shape_extension_rejects_non_fixed_size_list_storage() {
    let field =
        field_with_extension(DataType::Float64, FixedShapeTensor::NAME, Some(r#"{"shape":[2]}"#));

    let err = deserialize_registered_extension(&field).expect_err("wrong storage must fail");
    let message = err.to_string();
    assert!(
        message.contains("requires FixedSizeList storage"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` field with no extension-metadata entry must be rejected.
///
/// The storage type comes from `arrays_to_variable_shape_tensor`, so only the missing
/// metadata should trigger the error.
#[test]
fn variable_shape_extension_rejects_missing_metadata() {
    let arrays = vec![
        ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0_f32, 2.0]).expect("shape valid"),
        ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0_f32, 4.0]).expect("shape valid"),
    ];
    let (_field, storage) = arrays_to_variable_shape_tensor("ragged", arrays, None)
        .expect("variable tensor should build");
    // Reuse the real storage type but attach only the extension name (no metadata JSON).
    let field = field_with_extension(storage.data_type().clone(), VariableShapeTensor::NAME, None);

    let err = deserialize_registered_extension(&field).expect_err("missing metadata must fail");
    let message = err.to_string();
    assert!(
        message.contains("metadata missing"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` field whose top-level storage is not a `Struct` must be rejected.
#[test]
fn variable_shape_extension_rejects_non_struct_storage() {
    let field = field_with_extension(DataType::Float32, VariableShapeTensor::NAME, Some(r"{}"));

    let err =
        deserialize_registered_extension(&field).expect_err("wrong top-level storage must fail");
    let message = err.to_string();
    assert!(
        message.contains("requires Struct storage"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` struct storage that has a `shape` child but no `data` child
/// must be rejected.
#[test]
fn variable_shape_extension_rejects_missing_data_field() {
    let field = field_with_extension(
        DataType::Struct(vec![Field::new("shape", simple_shape_type(1), false)].into()),
        VariableShapeTensor::NAME,
        Some(r"{}"),
    );

    let err = deserialize_registered_extension(&field).expect_err("missing data field must fail");
    let message = err.to_string();
    assert!(
        message.contains("missing 'data' field"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` struct storage that has a `data` child but no `shape` child
/// must be rejected.
#[test]
fn variable_shape_extension_rejects_missing_shape_field() {
    let field = field_with_extension(
        DataType::Struct(vec![Field::new("data", simple_list_f32_type(), false)].into()),
        VariableShapeTensor::NAME,
        Some(r"{}"),
    );

    let err = deserialize_registered_extension(&field).expect_err("missing shape field must fail");
    let message = err.to_string();
    assert!(
        message.contains("missing 'shape' field"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` whose `data` child is a flat `Float32` column, rather than a
/// `List`, must be rejected.
#[test]
fn variable_shape_extension_rejects_wrong_data_field_type() {
    let field = field_with_extension(
        DataType::Struct(
            vec![
                // Deliberately not a List: this is the defect under test.
                Field::new("data", DataType::Float32, false),
                Field::new("shape", simple_shape_type(1), false),
            ]
            .into(),
        ),
        VariableShapeTensor::NAME,
        Some(r"{}"),
    );

    let err =
        deserialize_registered_extension(&field).expect_err("wrong data field type must fail");
    let message = err.to_string();
    assert!(
        message.contains("'data' field must be List"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` whose `shape` child is a variable-size `List<Int32>`, rather
/// than a `FixedSizeList`, must be rejected.
#[test]
fn variable_shape_extension_rejects_wrong_shape_field_type() {
    let field = field_with_extension(
        DataType::Struct(
            vec![
                Field::new("data", simple_list_f32_type(), false),
                // Deliberately a plain List instead of FixedSizeList: the defect under test.
                Field::new(
                    "shape",
                    DataType::List(Arc::new(Field::new("item", DataType::Int32, false))),
                    false,
                ),
            ]
            .into(),
        ),
        VariableShapeTensor::NAME,
        Some(r"{}"),
    );

    let err =
        deserialize_registered_extension(&field).expect_err("wrong shape field type must fail");
    let message = err.to_string();
    assert!(
        message.contains("'shape' field must be FixedSizeList"),
        "unexpected error message: {message}"
    );
}

/// A `VariableShapeTensor` with well-formed struct storage but syntactically invalid
/// metadata JSON must be rejected with a parse error.
#[test]
fn variable_shape_extension_rejects_invalid_metadata_json() {
    let field = field_with_extension(
        DataType::Struct(
            vec![
                Field::new("data", simple_list_f32_type(), false),
                Field::new("shape", simple_shape_type(1), false),
            ]
            .into(),
        ),
        VariableShapeTensor::NAME,
        // Unterminated object: not valid JSON.
        Some("{invalid"),
    );

    let err = deserialize_registered_extension(&field).expect_err("invalid metadata must fail");
    let message = err.to_string();
    assert!(
        message.contains("metadata parse failed"),
        "unexpected error message: {message}"
    );
}

/// `variable_shape_tensor_iter` must reject a field whose declared storage type disagrees
/// with the array actually supplied.
///
/// The field declares `shape` as `FixedSizeList<Int32, 2>`, but the array's `shape` column
/// is `FixedSizeList<Int32, 1>`. The field itself is well-formed, so the failure is
/// expected to come from the field/array comparison rather than from metadata parsing.
#[test]
fn variable_shape_tensor_iter_rejects_field_storage_mismatch_after_parse() {
    // `data` column: a single list row holding [1.0, 2.0].
    let data_values = Float32Array::new(ScalarBuffer::from(vec![1.0_f32, 2.0]), None);
    let data_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0_i32, 2_i32]));
    let data_array: ArrayRef = Arc::new(ListArray::new(
        Arc::new(Field::new("item", DataType::Float32, false)),
        data_offsets,
        Arc::new(data_values),
        None,
    ));

    // `shape` column: a single row whose fixed-size list has length 1 (value [2]).
    let shape_values = Int32Array::new(ScalarBuffer::from(vec![2_i32]), None);
    let shape_array: ArrayRef = Arc::new(FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Int32, false)),
        1,
        Arc::new(shape_values),
        None,
    ));

    let array = StructArray::new(
        vec![
            Field::new("data", data_array.data_type().clone(), false),
            Field::new("shape", shape_array.data_type().clone(), false),
        ]
        .into(),
        vec![data_array, shape_array],
        None,
    );

    // Declares a two-element shape list, which does not match the array built above.
    let field = field_with_extension(
        DataType::Struct(
            vec![
                Field::new("data", simple_list_f32_type(), false),
                Field::new("shape", simple_shape_type(2), false),
            ]
            .into(),
        ),
        VariableShapeTensor::NAME,
        Some(r"{}"),
    );

    let result =
        ndarrow::variable_shape_tensor_iter::<arrow_array::types::Float32Type>(&field, &array);
    // Pattern-match instead of `expect_err`, which would require `Debug` on the Ok type.
    let Err(err) = result else {
        panic!("field/array storage mismatch must fail");
    };
    let message = err.to_string();
    assert!(
        message.contains("data type mismatch"),
        "unexpected error message: {message}"
    );
}