use std::sync::Arc;
use arrow_array::{
Array, ArrayRef, FixedSizeListArray, Int32Array, ListArray, PrimitiveArray, StructArray,
types::ArrowPrimitiveType,
};
use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow_schema::{
ArrowError, DataType, Field,
extension::{
EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY, ExtensionType, FixedShapeTensor,
VariableShapeTensor,
},
};
use ndarray::{ArrayD, ArrayViewD, IxDyn};
use serde::{Deserialize, Serialize};
use crate::{element::NdarrowElement, error::NdarrowError};
/// JSON wire form of `arrow.fixed_shape_tensor` extension metadata.
///
/// `shape` is mandatory; the optional entries are skipped entirely when
/// serializing (rather than emitted as `null`) so the output stays
/// compatible with other Arrow implementations' metadata parsers.
#[derive(Debug, Deserialize, Serialize)]
struct FixedShapeTensorWireMetadata {
    // Per-row tensor shape (the batch axis is not included).
    shape: Vec<usize>,
    // Optional human-readable dimension names.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    dim_names: Option<Vec<String>>,
    // Optional logical-to-physical dimension permutation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    permutations: Option<Vec<usize>>,
}
/// JSON wire form of `arrow.variable_shape_tensor` extension metadata.
///
/// All fields are optional and omitted from serialized output when `None`.
/// The tensor rank itself is not carried here; it is derived from the
/// storage type's `shape` FixedSizeList size.
#[derive(Debug, Deserialize, Serialize)]
struct VariableShapeTensorWireMetadata {
    // Optional human-readable dimension names.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    dim_names: Option<Vec<String>>,
    // Optional logical-to-physical dimension permutation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    permutations: Option<Vec<usize>>,
    // Per-dimension pinned sizes; a `None` entry means that dimension may
    // vary from row to row.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    uniform_shape: Option<Vec<Option<i32>>>,
}
/// Verifies that `field` is annotated with the Arrow extension type name
/// `expected_name`.
///
/// # Errors
/// Returns `NdarrowError::Arrow(InvalidArgumentError)` when the field carries
/// a different extension name, or none at all.
fn require_extension_name(field: &Field, expected_name: &'static str) -> Result<(), NdarrowError> {
    let invalid = |message: String| NdarrowError::Arrow(ArrowError::InvalidArgumentError(message));
    match field.extension_type_name() {
        Some(name) => {
            if name == expected_name {
                Ok(())
            } else {
                Err(invalid(format!(
                    "Field extension type mismatch, expected {expected_name}, found {name}",
                )))
            }
        }
        None => Err(invalid("Field extension type name missing".to_owned())),
    }
}
/// Returns `field` with the Arrow extension name and metadata keys set.
///
/// Pre-existing metadata entries are kept; the two extension keys are
/// overwritten with `extension_name` and `metadata_json`.
pub(crate) fn field_with_extension_metadata(
    field: Field,
    extension_name: &'static str,
    metadata_json: String,
) -> Field {
    let mut annotated = field.metadata().clone();
    annotated.extend([
        (EXTENSION_TYPE_NAME_KEY.to_owned(), extension_name.to_owned()),
        (EXTENSION_TYPE_METADATA_KEY.to_owned(), metadata_json),
    ]);
    field.with_metadata(annotated)
}
/// Validates `field` as an `arrow.fixed_shape_tensor` extension and rebuilds
/// the typed extension descriptor.
///
/// Checks, in order: the extension name, presence and parseability of the
/// JSON wire metadata, that storage is a `FixedSizeList`, and finally that the
/// reconstructed extension accepts the field's full data type.
pub(crate) fn parse_fixed_shape_extension(field: &Field) -> Result<FixedShapeTensor, NdarrowError> {
    require_extension_name(field, FixedShapeTensor::NAME)?;
    let json = field.extension_type_metadata().ok_or_else(|| NdarrowError::InvalidMetadata {
        message: "arrow.fixed_shape_tensor metadata missing".to_owned(),
    })?;
    let wire: FixedShapeTensorWireMetadata =
        serde_json::from_str(json).map_err(|e| NdarrowError::InvalidMetadata {
            message: format!("arrow.fixed_shape_tensor metadata parse failed: {e}"),
        })?;
    // The tensor value type is the element type of the FixedSizeList storage.
    let DataType::FixedSizeList(item, _) = field.data_type() else {
        return Err(NdarrowError::TypeMismatch {
            message: format!(
                "arrow.fixed_shape_tensor requires FixedSizeList storage, found {}",
                field.data_type()
            ),
        });
    };
    let extension = FixedShapeTensor::try_new(
        item.data_type().clone(),
        wire.shape,
        wire.dim_names,
        wire.permutations,
    )
    .map_err(NdarrowError::from)?;
    extension.supports_data_type(field.data_type()).map_err(NdarrowError::from)?;
    Ok(extension)
}
/// Validates `field` as an `arrow.variable_shape_tensor` extension and
/// rebuilds the typed extension descriptor.
///
/// Checks, in order: the extension name, presence and parseability of the
/// JSON wire metadata, that storage is a Struct with a `data` (List) child
/// and a `shape` (FixedSizeList) child, and finally that the reconstructed
/// extension accepts the field's full data type.
pub(crate) fn parse_variable_shape_extension(
    field: &Field,
) -> Result<VariableShapeTensor, NdarrowError> {
    require_extension_name(field, VariableShapeTensor::NAME)?;
    let raw_metadata =
        field.extension_type_metadata().ok_or_else(|| NdarrowError::InvalidMetadata {
            message: "arrow.variable_shape_tensor metadata missing".to_owned(),
        })?;
    let metadata: VariableShapeTensorWireMetadata =
        serde_json::from_str(raw_metadata).map_err(|e| NdarrowError::InvalidMetadata {
            message: format!("arrow.variable_shape_tensor metadata parse failed: {e}"),
        })?;
    let (value_type, dimensions) = match field.data_type() {
        DataType::Struct(fields) => {
            // `find` yields `(index, field)`; only the field half (`.1`) is used below.
            let data_field = fields.find("data").ok_or_else(|| NdarrowError::TypeMismatch {
                message: "arrow.variable_shape_tensor storage missing 'data' field".to_owned(),
            })?;
            let shape_field = fields.find("shape").ok_or_else(|| NdarrowError::TypeMismatch {
                message: "arrow.variable_shape_tensor storage missing 'shape' field".to_owned(),
            })?;
            // The element type of the `data` list becomes the tensor value type.
            let value_type = match data_field.1.data_type() {
                DataType::List(item) => item.data_type().clone(),
                data_type => {
                    return Err(NdarrowError::TypeMismatch {
                        message: format!(
                            "arrow.variable_shape_tensor 'data' field must be List, found {data_type}"
                        ),
                    });
                }
            };
            // The fixed list size of `shape` is the tensor rank (dimension count).
            let dimensions = match shape_field.1.data_type() {
                DataType::FixedSizeList(_, list_size) => usize::try_from(*list_size).map_err(
                    |_| NdarrowError::TypeMismatch {
                        message: format!(
                            "arrow.variable_shape_tensor shape list size must be non-negative, found {list_size}"
                        ),
                    },
                )?,
                data_type => {
                    return Err(NdarrowError::TypeMismatch {
                        message: format!(
                            "arrow.variable_shape_tensor 'shape' field must be FixedSizeList, found {data_type}"
                        ),
                    });
                }
            };
            (value_type, dimensions)
        }
        data_type => {
            return Err(NdarrowError::TypeMismatch {
                message: format!(
                    "arrow.variable_shape_tensor requires Struct storage, found {data_type}"
                ),
            });
        }
    };
    // `try_new` may reject inconsistent metadata (its error is converted into
    // `NdarrowError`); `supports_data_type` then re-checks the full storage type.
    let extension = VariableShapeTensor::try_new(
        value_type,
        dimensions,
        metadata.dim_names,
        metadata.permutations,
        metadata.uniform_shape,
    )
    .map_err(NdarrowError::from)?;
    extension.supports_data_type(field.data_type()).map_err(NdarrowError::from)?;
    Ok(extension)
}
/// Extracts the per-row tensor shape from a validated
/// `arrow.fixed_shape_tensor` field.
///
/// Full extension validation runs first (preserving its error ordering);
/// the wire metadata is then re-read to recover the declared shape.
pub(crate) fn parse_fixed_shape_metadata(field: &Field) -> Result<Vec<usize>, NdarrowError> {
    parse_fixed_shape_extension(field)?;
    let json = field.extension_type_metadata().ok_or_else(|| NdarrowError::InvalidMetadata {
        message: "arrow.fixed_shape_tensor metadata missing".to_owned(),
    })?;
    serde_json::from_str::<FixedShapeTensorWireMetadata>(json)
        .map(|wire| wire.shape)
        .map_err(|e| NdarrowError::InvalidMetadata {
            message: format!("arrow.fixed_shape_tensor metadata parse failed: {e}"),
        })
}
/// Borrowed, validated raw storage of a variable-shape tensor column.
///
/// Produced by `variable_shape_tensor_storage`; all components are
/// guaranteed null-free at the levels that function checks.
pub(crate) struct VariableShapeTensorStorage<'a> {
    // The `data` child: one variable-length list of values per row.
    pub data: &'a ListArray,
    // The element field of the `data` list (carries the value type).
    pub data_item_field: &'a Field,
    // Flattened per-row shapes: `dimensions` i32 entries per row.
    pub shape_values: &'a Int32Array,
    // Tensor rank shared by every row.
    pub dimensions: usize,
    // Optional per-dimension pinned sizes from the extension metadata.
    pub uniform_shape: Option<Vec<Option<i32>>>,
}
impl<'a> VariableShapeTensorStorage<'a> {
    /// Creates a cursor that decodes this storage's rows sequentially.
    #[must_use]
    pub(crate) fn row_cursor(&self) -> VariableShapeTensorRowCursor<'a> {
        let row_count = self.data.len();
        VariableShapeTensorRowCursor {
            data: self.data,
            shape_values: self.shape_values,
            dimensions: self.dimensions,
            uniform_shape: self.uniform_shape.clone(),
            index: 0,
            len: row_count,
        }
    }
}
/// One decoded logical row of a variable-shape tensor column.
pub(crate) struct VariableShapeTensorRow {
    // Row index within the batch.
    pub row: usize,
    // Start offset (inclusive) into the flat value buffer.
    pub start: usize,
    // End offset (exclusive) into the flat value buffer.
    pub end: usize,
    // This row's tensor shape; its product equals `end - start`.
    pub shape: Vec<usize>,
}
/// Sequential row decoder over validated variable-shape tensor storage.
///
/// Created via `VariableShapeTensorStorage::row_cursor`; advance with
/// `next_row`.
pub(crate) struct VariableShapeTensorRowCursor<'a> {
    // Next row index to decode.
    index: usize,
    // Total number of rows.
    len: usize,
    data: &'a ListArray,
    shape_values: &'a Int32Array,
    dimensions: usize,
    uniform_shape: Option<Vec<Option<i32>>>,
}
/// Converts an Arrow `i32` offset to `usize`, rejecting negative values with
/// a row-contextualized `InvalidMetadata` error.
fn tensor_offset_to_usize(offset: i32, row: usize, context: &str) -> Result<usize, NdarrowError> {
    match usize::try_from(offset) {
        Ok(value) => Ok(value),
        Err(_) => Err(NdarrowError::InvalidMetadata {
            message: format!("negative {context} at row {row}: {offset}"),
        }),
    }
}
/// Decodes one logical tensor row from the packed variable-shape buffers.
///
/// `data_offsets` is the list array's offset buffer (length = rows + 1);
/// `shape_values` is the flattened shape buffer with `dimensions` entries per
/// row. Validates that offsets and dimensions are non-negative, that every
/// dimension matches `uniform_shape` where one is pinned, and that the shape
/// product equals the row's data length.
fn decode_variable_shape_tensor_row(
    row: usize,
    data_offsets: &[i32],
    shape_values: &[i32],
    dimensions: usize,
    uniform_shape: Option<&[Option<i32>]>,
) -> Result<VariableShapeTensorRow, NdarrowError> {
    // This row's half-open range within the flat value buffer.
    let start = tensor_offset_to_usize(data_offsets[row], row, "data offset")?;
    let end = tensor_offset_to_usize(data_offsets[row + 1], row, "data end offset")?;
    // This row's slot within the flattened shape buffer, overflow-checked.
    let shape_start = row.checked_mul(dimensions).ok_or_else(|| NdarrowError::ShapeMismatch {
        message: format!("shape index overflow at row {row}"),
    })?;
    let shape_end = shape_start.checked_add(dimensions).ok_or_else(|| {
        NdarrowError::ShapeMismatch { message: format!("shape range overflow at row {row}") }
    })?;
    let mut shape = Vec::with_capacity(dimensions);
    for (dim_idx, raw) in shape_values[shape_start..shape_end].iter().copied().enumerate() {
        let dim = usize::try_from(raw).map_err(|_| NdarrowError::ShapeMismatch {
            message: format!("negative tensor dimension at row {row}, dim {dim_idx}: {raw}"),
        })?;
        // Enforce any pinned (`Some`) entries of uniform_shape.
        // NOTE(review): `uniform_shape[dim_idx]` assumes its length equals
        // `dimensions`; callers obtain it from the validated extension, but
        // confirm that invariant upstream.
        if let Some(uniform_shape) = uniform_shape
            && let Some(expected) = uniform_shape[dim_idx]
        {
            let expected =
                usize::try_from(expected).map_err(|_| NdarrowError::InvalidMetadata {
                    message: format!(
                        "uniform_shape contains negative dimension at index {dim_idx}: {expected}"
                    ),
                })?;
            if dim != expected {
                return Err(NdarrowError::ShapeMismatch {
                    message: format!(
                        "row {row} dimension {dim_idx} violates uniform_shape: expected {expected}, found {dim}"
                    ),
                });
            }
        }
        shape.push(dim);
    }
    // The element count implied by the shape must match the data slice length.
    // Note an empty shape (rank 0) yields a product of 1.
    let required_len = shape
        .iter()
        .try_fold(1_usize, |acc, dim| acc.checked_mul(*dim))
        .ok_or_else(|| NdarrowError::ShapeMismatch {
            message: format!("row {row} shape product overflows usize: {shape:?}"),
        })?;
    if required_len != (end - start) {
        return Err(NdarrowError::ShapeMismatch {
            message: format!(
                "row {row} shape product ({required_len}) does not match data length ({})",
                end - start
            ),
        });
    }
    Ok(VariableShapeTensorRow { row, start, end, shape })
}
impl VariableShapeTensorRowCursor<'_> {
    /// Decodes and returns the next row, or `None` once every row has been
    /// consumed. Decoding failures are surfaced per row.
    pub(crate) fn next_row(&mut self) -> Option<Result<VariableShapeTensorRow, NdarrowError>> {
        let row = self.index;
        if row >= self.len {
            return None;
        }
        self.index = row + 1;
        let decoded = decode_variable_shape_tensor_row(
            row,
            self.data.value_offsets(),
            self.shape_values.values().as_ref(),
            self.dimensions,
            self.uniform_shape.as_deref(),
        );
        Some(decoded)
    }
}
/// Validates `(field, array)` as an `arrow.variable_shape_tensor` column and
/// borrows its raw storage components.
///
/// Rejects nulls in the `data` list, the `shape` list, and the flattened
/// shape values. Outer (struct-level) nulls are NOT checked here; callers
/// that care must inspect `array.nulls()` themselves.
pub(crate) fn variable_shape_tensor_storage<'a>(
    field: &Field,
    array: &'a StructArray,
) -> Result<VariableShapeTensorStorage<'a>, NdarrowError> {
    let extension = parse_variable_shape_extension(field)?;
    // Re-validate against the array's own type, which may differ from the field's.
    extension.supports_data_type(array.data_type()).map_err(NdarrowError::from)?;
    // After `supports_data_type` succeeds the downcasts below cannot fail:
    // column 0 is the `data` List, column 1 the `shape` FixedSizeList<Int32>.
    let data = array
        .column(0)
        .as_any()
        .downcast_ref::<ListArray>()
        .expect("extension storage guarantees 'data' is ListArray");
    if data.null_count() > 0 {
        return Err(NdarrowError::NullsPresent { null_count: data.null_count() });
    }
    let data_item_field = match data.data_type() {
        DataType::List(item) => item.as_ref(),
        _ => unreachable!("validated variable-shape tensor storage guarantees List data"),
    };
    let shape = array
        .column(1)
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .expect("extension storage guarantees 'shape' is FixedSizeListArray");
    if shape.null_count() > 0 {
        return Err(NdarrowError::NullsPresent { null_count: shape.null_count() });
    }
    let shape_values = shape
        .values()
        .as_any()
        .downcast_ref::<Int32Array>()
        .expect("extension storage guarantees variable shape inner Int32");
    if shape_values.null_count() > 0 {
        return Err(NdarrowError::NullsPresent { null_count: shape_values.null_count() });
    }
    Ok(VariableShapeTensorStorage {
        data,
        data_item_field,
        shape_values,
        dimensions: extension.dimensions(),
        uniform_shape: extension.uniform_shapes().map(<[Option<i32>]>::to_vec),
    })
}
/// A zero-copy view of one row of a variable-shape tensor batch: a borrowed
/// flat value slice plus this row's decoded shape.
pub struct VariableShapeTensorRowView<'a, T> {
    // Row index within the batch.
    row: usize,
    // Borrowed slice of this row's values from the shared flat buffer.
    values: &'a [T],
    // Decoded shape; its product equals `values.len()`.
    shape: Vec<usize>,
}
impl<'a, T> VariableShapeTensorRowView<'a, T> {
    /// Row index within the batch this view was taken from.
    #[must_use]
    pub fn row(&self) -> usize {
        self.row
    }
    /// This row's tensor shape.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape
    }
    /// The row's flat (row-major) values, borrowed from the Arrow buffer.
    #[must_use]
    pub fn values(&self) -> &'a [T] {
        self.values
    }
}
impl<'a, T> VariableShapeTensorRowView<'a, T>
where
    T: NdarrowElement,
{
    /// Reinterprets the row's flat value slice as an `ndarray` view with this
    /// row's shape.
    ///
    /// # Errors
    /// Fails if the slice length does not match the shape product.
    pub fn as_array_viewd(&self) -> Result<ArrayViewD<'a, T>, NdarrowError> {
        let dims = IxDyn(&self.shape);
        ArrayViewD::from_shape(dims, self.values).map_err(NdarrowError::from)
    }
}
/// A typed, columnar view over a whole variable-shape tensor batch.
///
/// Holds borrowed Arrow buffers; rows are decoded on demand via `row`/`iter`.
/// Outer struct-level nulls are retained (see `nulls`) rather than rejected.
pub struct VariableShapeTensorBatchView<'a, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    // Outer validity buffer of the backing StructArray, if any.
    nulls: Option<&'a NullBuffer>,
    // The `data` list: one variable-length value run per row.
    data: &'a ListArray,
    // The flat primitive values underlying `data`.
    data_values: &'a PrimitiveArray<T>,
    // Flattened per-row shapes (`dimensions` entries per row).
    shape_values: &'a Int32Array,
    // Tensor rank shared by every row.
    dimensions: usize,
    // Optional per-dimension pinned sizes from the extension metadata.
    uniform_shape: Option<Vec<Option<i32>>>,
}
// Manual `Clone` impl: a derive would add a spurious `T: Clone` bound on the
// marker type parameter, while every field here is either a `Copy` reference,
// a `usize`, or a clonable `Option<Vec<_>>`.
impl<T> Clone for VariableShapeTensorBatchView<'_, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    fn clone(&self) -> Self {
        Self {
            nulls: self.nulls,
            data: self.data,
            data_values: self.data_values,
            shape_values: self.shape_values,
            dimensions: self.dimensions,
            // Only this field allocates; everything else is a cheap copy.
            uniform_shape: self.uniform_shape.clone(),
        }
    }
}
impl<'a, T> VariableShapeTensorBatchView<'a, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    /// Number of logical tensor rows in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// `true` when the batch holds no rows.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Outer (struct-level) validity buffer, if any. Rows flagged null here
    /// are still decodable; interpreting them is the caller's choice.
    #[must_use]
    pub fn nulls(&self) -> Option<&'a NullBuffer> {
        self.nulls
    }
    /// Tensor rank shared by every row.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
    /// Optional per-dimension pinned sizes (`None` entries vary per row).
    #[must_use]
    pub fn uniform_shape(&self) -> Option<&[Option<i32>]> {
        self.uniform_shape.as_deref()
    }
    /// Raw list offsets into the flat value buffer (length = rows + 1).
    #[must_use]
    pub fn data_offsets(&self) -> &[i32] {
        self.data.value_offsets()
    }
    /// The flat, contiguous value buffer backing all rows.
    #[must_use]
    pub fn values(&self) -> &'a [T::Native] {
        self.data_values.values().as_ref()
    }
    /// Flattened per-row shape entries (`dimensions` per row).
    #[must_use]
    pub fn shape_values(&self) -> &[i32] {
        self.shape_values.values().as_ref()
    }
    /// Decodes and borrows a single row, validating its offsets, shape, and
    /// any uniform-shape constraints.
    ///
    /// # Errors
    /// `ShapeMismatch` for an out-of-bounds index or inconsistent row data.
    pub fn row(
        &self,
        index: usize,
    ) -> Result<VariableShapeTensorRowView<'a, T::Native>, NdarrowError> {
        if index >= self.len() {
            return Err(NdarrowError::ShapeMismatch {
                message: format!(
                    "row index {index} out of bounds for batch of length {}",
                    self.len()
                ),
            });
        }
        let row = decode_variable_shape_tensor_row(
            index,
            self.data.value_offsets(),
            self.shape_values.values().as_ref(),
            self.dimensions,
            self.uniform_shape.as_deref(),
        )?;
        Ok(VariableShapeTensorRowView {
            row: row.row,
            values: &self.data_values.values().as_ref()[row.start..row.end],
            shape: row.shape,
        })
    }
    /// Returns an iterator over `(row_index, view)` pairs. The iterator owns
    /// a clone of this view (cheap: borrowed references plus one small Vec).
    #[must_use]
    pub fn iter(&self) -> VariableShapeTensorIter<'a, T> {
        VariableShapeTensorIter { batch: (*self).clone(), index: 0 }
    }
}
/// Allows `for item in &batch`, delegating to [`VariableShapeTensorBatchView::iter`].
impl<'a, T> IntoIterator for &'a VariableShapeTensorBatchView<'a, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    type IntoIter = VariableShapeTensorIter<'a, T>;
    type Item = Result<(usize, ArrayViewD<'a, T::Native>), NdarrowError>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// Builds a typed batch view over an `arrow.variable_shape_tensor` struct
/// array, validating storage and rejecting nulls in the flat value buffer.
///
/// Outer (struct-level) nulls are preserved and exposed through
/// `VariableShapeTensorBatchView::nulls` rather than rejected.
pub fn variable_shape_tensor_batch_view<'a, T>(
    field: &Field,
    array: &'a StructArray,
) -> Result<VariableShapeTensorBatchView<'a, T>, NdarrowError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    let storage = variable_shape_tensor_storage(field, array)?;
    let raw_values = storage.data.values();
    let Some(data_values) = raw_values.as_any().downcast_ref::<PrimitiveArray<T>>() else {
        return Err(NdarrowError::InnerTypeMismatch {
            message: format!(
                "expected variable tensor data values as {:?}, found {}",
                T::DATA_TYPE,
                raw_values.data_type()
            ),
        });
    };
    match data_values.null_count() {
        0 => Ok(VariableShapeTensorBatchView {
            nulls: array.nulls(),
            data: storage.data,
            data_values,
            shape_values: storage.shape_values,
            dimensions: storage.dimensions,
            uniform_shape: storage.uniform_shape,
        }),
        null_count => Err(NdarrowError::NullsPresent { null_count }),
    }
}
/// Borrows an entire fixed-shape tensor column as one `ndarray` view whose
/// first axis is the batch (row) axis, followed by the per-row tensor shape.
///
/// Zero-copy: the resulting view aliases the Arrow value buffer directly.
///
/// # Errors
/// - `NullsPresent` if the list or its inner values contain nulls.
/// - `TypeMismatch` if `field` and `array` disagree on the data type.
/// - Extension/metadata errors from `parse_fixed_shape_metadata`.
/// - `InnerTypeMismatch` if the values are not a `PrimitiveArray<T>`.
pub fn fixed_shape_tensor_as_array_viewd<'a, T>(
    field: &Field,
    array: &'a FixedSizeListArray,
) -> Result<ArrayViewD<'a, T::Native>, NdarrowError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    if array.null_count() > 0 {
        return Err(NdarrowError::NullsPresent { null_count: array.null_count() });
    }
    if !field.data_type().equals_datatype(array.data_type()) {
        return Err(NdarrowError::TypeMismatch {
            message: format!(
                "field data type ({}) does not match array data type ({})",
                field.data_type(),
                array.data_type()
            ),
        });
    }
    // Validates the extension annotation and yields the per-row shape.
    let tensor_shape = parse_fixed_shape_metadata(field)?;
    let values = array.values().as_any().downcast_ref::<PrimitiveArray<T>>().ok_or_else(|| {
        NdarrowError::InnerTypeMismatch {
            message: format!(
                "expected inner array type {:?}, found {}",
                T::DATA_TYPE,
                array.values().data_type()
            ),
        }
    })?;
    if values.null_count() > 0 {
        return Err(NdarrowError::NullsPresent { null_count: values.null_count() });
    }
    // Prepend the batch axis: [len, shape...].
    let mut full_shape = Vec::with_capacity(tensor_shape.len() + 1);
    full_shape.push(array.len());
    full_shape.extend_from_slice(&tensor_shape);
    let slice: &[T::Native] = values.values().as_ref();
    // `from_shape` fails when slice.len() != product(full_shape).
    ArrayViewD::from_shape(IxDyn(&full_shape), slice).map_err(NdarrowError::from)
}
/// Converts an `ndarray` of shape `[batch, shape...]` into an
/// `arrow.fixed_shape_tensor` column (annotated field + FixedSizeList array).
///
/// The first axis becomes the batch axis; the remaining axes form the per-row
/// tensor shape recorded in the extension metadata. For arrays already in
/// standard (row-major) layout the value buffer is moved, not copied.
///
/// # Errors
/// - `ShapeMismatch` for rank-0 input, shape-product overflow, a length that
///   does not match `batch * product(shape)`, or an element count > `i32::MAX`.
/// - Extension/metadata construction or serialization failures.
pub fn arrayd_to_fixed_shape_tensor<T>(
    field_name: &str,
    array: ArrayD<T>,
) -> Result<(Field, FixedSizeListArray), NdarrowError>
where
    T: NdarrowElement,
{
    let shape = array.shape().to_vec();
    if shape.is_empty() {
        return Err(NdarrowError::ShapeMismatch {
            message: "ArrayD must have at least one dimension (batch axis)".to_owned(),
        });
    }
    let batch = shape[0];
    let tensor_shape = shape[1..].to_vec();
    // Rank-1 input means scalar rows: one element per row.
    let element_count = if tensor_shape.is_empty() {
        1
    } else {
        tensor_shape.iter().try_fold(1_usize, |acc, dim| acc.checked_mul(*dim)).ok_or_else(
            || NdarrowError::ShapeMismatch {
                message: format!("tensor shape product overflows usize: {tensor_shape:?}"),
            },
        )?
    };
    let expected_len =
        batch.checked_mul(element_count).ok_or_else(|| NdarrowError::ShapeMismatch {
            message: format!("batch × tensor_shape product overflows usize: {shape:?}"),
        })?;
    if expected_len != array.len() {
        return Err(NdarrowError::ShapeMismatch {
            message: format!(
                "array length ({}) does not match batch × tensor element count ({expected_len})",
                array.len()
            ),
        });
    }
    // Non-contiguous layouts are copied into row-major order; contiguous
    // arrays are moved as-is (zero-copy path — see the pointer-equality test).
    let standard =
        if array.is_standard_layout() { array } else { array.as_standard_layout().into_owned() };
    let total_len = standard.len();
    // `into_raw_vec_and_offset` may hand back a vec whose live data starts at
    // `offset`; only `total_len` elements from there are meaningful.
    let (mut raw_vec, offset) = standard.into_raw_vec_and_offset();
    let start = offset.unwrap_or(0);
    let vec = if start == 0 {
        // Fast path: drop trailing slack and reuse the buffer wholesale.
        raw_vec.truncate(total_len);
        raw_vec
    } else {
        raw_vec[start..start + total_len].to_vec()
    };
    let values_array = PrimitiveArray::<T::ArrowType>::new(ScalarBuffer::from(vec), None);
    let item_field = Arc::new(Field::new("item", T::data_type(), false));
    let value_length = i32::try_from(element_count).map_err(|_| NdarrowError::ShapeMismatch {
        message: format!("tensor element count {element_count} exceeds Arrow i32 limits"),
    })?;
    let fsl = FixedSizeListArray::new(item_field, value_length, Arc::new(values_array), None);
    // Sanity-check that the constructed storage really matches the extension.
    let extension = FixedShapeTensor::try_new(T::data_type(), tensor_shape.clone(), None, None)
        .map_err(NdarrowError::from)?;
    extension.supports_data_type(fsl.data_type()).map_err(NdarrowError::from)?;
    let metadata_json = serde_json::to_string(&FixedShapeTensorWireMetadata {
        shape: tensor_shape,
        dim_names: None,
        permutations: None,
    })
    .map_err(|e| NdarrowError::InvalidMetadata {
        message: format!("arrow.fixed_shape_tensor metadata serialization failed: {e}"),
    })?;
    let field = field_with_extension_metadata(
        Field::new(field_name, fsl.data_type().clone(), false),
        FixedShapeTensor::NAME,
        metadata_json,
    );
    Ok((field, fsl))
}
/// Iterator over the rows of a [`VariableShapeTensorBatchView`], yielding
/// `(row_index, ndarray view)` pairs. Owns a cheap clone of the batch view.
pub struct VariableShapeTensorIter<'a, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    batch: VariableShapeTensorBatchView<'a, T>,
    // Next row index to yield.
    index: usize,
}
impl<'a, T> Iterator for VariableShapeTensorIter<'a, T>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    type Item = Result<(usize, ArrayViewD<'a, T::Native>), NdarrowError>;

    /// Yields `(row_index, view)` pairs; per-row decode errors are surfaced
    /// as `Err` items rather than ending iteration.
    fn next(&mut self) -> Option<Self::Item> {
        let row = self.index;
        if row >= self.batch.len() {
            return None;
        }
        self.index = row + 1;
        let item = self.batch.row(row).and_then(|view| Ok((row, view.as_array_viewd()?)));
        Some(item)
    }
}
/// Builds a row iterator over a variable-shape tensor column while keeping the
/// outer (struct-level) null buffer available to the caller.
///
/// Unlike [`variable_shape_tensor_iter`], outer nulls do NOT cause an error;
/// they are returned alongside the iterator for the caller to interpret.
pub fn variable_shape_tensor_iter_masked<'a, T>(
    field: &Field,
    array: &'a StructArray,
) -> Result<(VariableShapeTensorIter<'a, T>, Option<&'a NullBuffer>), NdarrowError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    let batch = variable_shape_tensor_batch_view(field, array)?;
    Ok((batch.iter(), batch.nulls()))
}
/// Builds a row iterator over a variable-shape tensor column, rejecting any
/// outer (struct-level) nulls up front.
///
/// # Errors
/// `NullsPresent` when the struct array has null rows; otherwise whatever
/// [`variable_shape_tensor_iter_masked`] reports.
pub fn variable_shape_tensor_iter<'a, T>(
    field: &Field,
    array: &'a StructArray,
) -> Result<VariableShapeTensorIter<'a, T>, NdarrowError>
where
    T: ArrowPrimitiveType,
    T::Native: NdarrowElement,
{
    match array.null_count() {
        0 => variable_shape_tensor_iter_masked(field, array).map(|(iter, _nulls)| iter),
        null_count => Err(NdarrowError::NullsPresent { null_count }),
    }
}
pub(crate) fn push_tensor_shape(
shape: &[usize],
row: usize,
uniform_shape: Option<&[Option<i32>]>,
packed_shapes: &mut Vec<i32>,
) -> Result<(), NdarrowError> {
for (dim_idx, dim) in shape.iter().copied().enumerate() {
let dim_i32 = i32::try_from(dim).map_err(|_| NdarrowError::ShapeMismatch {
message: format!(
"tensor dimension exceeds i32 limits at row {row}, dim {dim_idx}: {dim}"
),
})?;
if let Some(uniform) = uniform_shape {
if let Some(expected) = uniform[dim_idx] {
if expected != dim_i32 {
return Err(NdarrowError::ShapeMismatch {
message: format!(
"row {row} dimension {dim_idx} violates uniform_shape: expected {expected}, found {dim_i32}"
),
});
}
}
}
packed_shapes.push(dim_i32);
}
Ok(())
}
/// Appends one tensor's elements (in standard, row-major order) to
/// `packed_values`, returning the element count as `i32` for offset
/// bookkeeping.
fn append_row_values<T>(
    array: ArrayD<T>,
    row: usize,
    packed_values: &mut Vec<T>,
) -> Result<i32, NdarrowError>
where
    T: NdarrowElement,
{
    // Non-contiguous layouts (e.g. transposed views) are copied into standard
    // layout first; contiguous arrays are moved as-is.
    let standard =
        if array.is_standard_layout() { array } else { array.as_standard_layout().into_owned() };
    let element_count = standard.len();
    // `into_raw_vec_and_offset` may hand back a vec whose live data starts at
    // `offset`; only `element_count` elements from there are meaningful.
    let (mut raw_vec, offset) = standard.into_raw_vec_and_offset();
    let start = offset.unwrap_or(0);
    if start == 0 {
        // Fast path: drop trailing slack and move the buffer wholesale.
        raw_vec.truncate(element_count);
        packed_values.extend(raw_vec);
    } else {
        packed_values.extend_from_slice(&raw_vec[start..start + element_count]);
    }
    i32::try_from(element_count).map_err(|_| NdarrowError::ShapeMismatch {
        message: format!(
            "tensor row element count exceeds i32 limits at row {row}: {element_count}"
        ),
    })
}
/// Packs a non-empty list of same-rank `ndarray` tensors into an
/// `arrow.variable_shape_tensor` column (annotated field + Struct array with
/// `data` List and `shape` FixedSizeList children).
///
/// The common rank is taken from the first tensor; every tensor must match it.
/// `uniform_shape`, if given, must have exactly `rank` entries and any `Some`
/// entry pins that dimension for all rows.
///
/// # Errors
/// - `InvalidMetadata` for an empty input or mismatched `uniform_shape` length.
/// - `ShapeMismatch` for rank mismatches, uniform-shape violations, or
///   counts/offsets exceeding `i32` limits.
pub fn arrays_to_variable_shape_tensor<T>(
    field_name: &str,
    arrays: Vec<ArrayD<T>>,
    uniform_shape: Option<Vec<Option<i32>>>,
) -> Result<(Field, StructArray), NdarrowError>
where
    T: NdarrowElement,
{
    if arrays.is_empty() {
        return Err(NdarrowError::InvalidMetadata {
            message: "arrays_to_variable_shape_tensor requires at least one tensor".to_owned(),
        });
    }
    // Rank is fixed by the first tensor; checked against every row below.
    let dimensions = arrays[0].ndim();
    if let Some(uniform) = &uniform_shape {
        if uniform.len() != dimensions {
            return Err(NdarrowError::InvalidMetadata {
                message: format!(
                    "uniform_shape length mismatch: expected {dimensions}, found {}",
                    uniform.len()
                ),
            });
        }
    }
    // Build the flat value buffer, the list offsets, and the flattened shapes
    // in a single pass over the input tensors.
    let mut offsets = Vec::with_capacity(arrays.len() + 1);
    offsets.push(0_i32);
    let mut running_offset = 0_i32;
    let mut packed_values = Vec::new();
    let mut packed_shapes = Vec::with_capacity(arrays.len() * dimensions);
    for (row, array) in arrays.into_iter().enumerate() {
        if array.ndim() != dimensions {
            return Err(NdarrowError::ShapeMismatch {
                message: format!(
                    "all tensors must share rank {dimensions}; row {row} has rank {}",
                    array.ndim()
                ),
            });
        }
        push_tensor_shape(array.shape(), row, uniform_shape.as_deref(), &mut packed_shapes)?;
        let element_count_i32 = append_row_values(array, row, &mut packed_values)?;
        running_offset = running_offset.checked_add(element_count_i32).ok_or_else(|| {
            NdarrowError::ShapeMismatch {
                message: "packed variable tensor offsets exceed i32 limits".to_owned(),
            }
        })?;
        offsets.push(running_offset);
    }
    // Assemble the `data` child: List<value_type> over the packed buffer.
    let data_offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
    let data_values = PrimitiveArray::<T::ArrowType>::new(ScalarBuffer::from(packed_values), None);
    let data_item_field = Arc::new(Field::new_list_field(T::data_type(), false));
    let data_list: ArrayRef =
        Arc::new(ListArray::new(data_item_field, data_offsets, Arc::new(data_values), None));
    // Assemble the `shape` child: FixedSizeList<Int32, rank>.
    let shape_values = Int32Array::new(ScalarBuffer::from(packed_shapes), None);
    let shape_item_field = Arc::new(Field::new("item", DataType::Int32, false));
    let shape_fsl: ArrayRef = Arc::new(FixedSizeListArray::new(
        shape_item_field,
        i32::try_from(dimensions).map_err(|_| NdarrowError::ShapeMismatch {
            message: format!("tensor rank exceeds i32 limits: {dimensions}"),
        })?,
        Arc::new(shape_values),
        None,
    ));
    let struct_fields = vec![
        Field::new("data", data_list.data_type().clone(), false),
        Field::new("shape", shape_fsl.data_type().clone(), false),
    ];
    let struct_array =
        StructArray::new(struct_fields.clone().into(), vec![data_list, shape_fsl], None);
    // Sanity-check that the constructed storage really matches the extension.
    let extension =
        VariableShapeTensor::try_new(T::data_type(), dimensions, None, None, uniform_shape.clone())
            .map_err(NdarrowError::from)?;
    extension.supports_data_type(struct_array.data_type()).map_err(NdarrowError::from)?;
    let metadata_json = serde_json::to_string(&VariableShapeTensorWireMetadata {
        dim_names: None,
        permutations: None,
        uniform_shape,
    })
    .map_err(|e| NdarrowError::InvalidMetadata {
        message: format!("arrow.variable_shape_tensor metadata serialization failed: {e}"),
    })?;
    let field = field_with_extension_metadata(
        Field::new(field_name, struct_array.data_type().clone(), false),
        VariableShapeTensor::NAME,
        metadata_json,
    );
    Ok((field, struct_array))
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use approx::assert_abs_diff_eq;
use arrow_array::{Float32Array, Float64Array, Int32Array, types::Float32Type};
use arrow_buffer::NullBuffer;
use arrow_schema::{
DataType, Field,
extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
};
use super::*;
#[test]
fn fixed_shape_tensor_roundtrip() {
let data: Vec<f32> = (0_u16..24_u16).map(|v| f32::from(v) * 0.5).collect();
let array = ArrayD::from_shape_vec(IxDyn(&[2, 3, 4]), data.clone()).unwrap();
let (field, fsl) = arrayd_to_fixed_shape_tensor("tensor", array).unwrap();
let view = fixed_shape_tensor_as_array_viewd::<Float32Type>(&field, &fsl).unwrap();
assert_eq!(view.shape(), &[2, 3, 4]);
for (actual, expected) in view.iter().zip(data.iter()) {
assert_abs_diff_eq!(*actual, *expected);
}
}
#[test]
fn fixed_shape_tensor_outbound_metadata_is_arrow_parseable() {
let array = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
let (field, _fsl) = arrayd_to_fixed_shape_tensor("tensor", array).unwrap();
let metadata = field.extension_type_metadata().unwrap();
assert!(!metadata.contains("\"dim_names\":null"));
assert!(!metadata.contains("\"permutations\":null"));
let extension = field.try_extension_type::<FixedShapeTensor>().unwrap();
assert_eq!(extension.dimensions(), 1);
assert_eq!(extension.list_size(), 2);
}
#[test]
fn fixed_shape_tensor_outbound_zero_copy_standard_layout() {
let data: Vec<f64> = (0_u16..12_u16).map(f64::from).collect();
let ptr = data.as_ptr();
let array = ArrayD::from_shape_vec(IxDyn(&[3, 4]), data).unwrap();
let (_field, fsl) = arrayd_to_fixed_shape_tensor("tensor", array).unwrap();
let inner = fsl.values().as_any().downcast_ref::<Float64Array>().unwrap();
assert_eq!(ptr, inner.values().as_ref().as_ptr());
}
#[test]
fn variable_shape_tensor_roundtrip() {
let a =
ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[1, 3]), vec![7.0_f32, 8.0, 9.0]).unwrap();
let (field, array) =
arrays_to_variable_shape_tensor("ragged", vec![a, b], Some(vec![None, Some(3)]))
.unwrap();
let mut iter = variable_shape_tensor_iter::<Float32Type>(&field, &array).unwrap();
let (row0, view0) = iter.next().unwrap().unwrap();
assert_eq!(row0, 0);
assert_eq!(view0.shape(), &[2, 3]);
assert_abs_diff_eq!(view0[[0, 0]], 1.0_f32);
assert_abs_diff_eq!(view0[[1, 2]], 6.0_f32);
let (row1, view1) = iter.next().unwrap().unwrap();
assert_eq!(row1, 1);
assert_eq!(view1.shape(), &[1, 3]);
assert_abs_diff_eq!(view1[[0, 0]], 7.0_f32);
assert_abs_diff_eq!(view1[[0, 2]], 9.0_f32);
assert!(iter.next().is_none());
}
#[test]
fn variable_shape_tensor_outbound_metadata_is_arrow_parseable() {
let a = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
let (field, _array) =
arrays_to_variable_shape_tensor("ragged", vec![a], Some(vec![None, Some(2)])).unwrap();
let metadata = field.extension_type_metadata().unwrap();
assert!(!metadata.contains("\"dim_names\":null"));
assert!(!metadata.contains("\"permutations\":null"));
let extension = field.try_extension_type::<VariableShapeTensor>().unwrap();
assert_eq!(extension.dimensions(), 2);
assert_eq!(extension.uniform_shapes(), Some(&[None, Some(2)][..]));
}
#[test]
fn variable_shape_tensor_uniform_shape_violation() {
let a =
ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[1, 4]), vec![7.0_f32, 8.0, 9.0, 10.0]).unwrap();
let err = arrays_to_variable_shape_tensor("ragged", vec![a, b], Some(vec![None, Some(3)]))
.unwrap_err();
assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn variable_shape_tensor_iterator_zero_copy() {
let a = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f32, 2.0, 3.0, 4.0]).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![5.0_f32, 6.0]).unwrap();
let (field, array) = arrays_to_variable_shape_tensor("ragged", vec![a, b], None).unwrap();
let data_list =
array.column_by_name("data").unwrap().as_any().downcast_ref::<ListArray>().unwrap();
let data_values = data_list.values().as_any().downcast_ref::<Float32Array>().unwrap();
let ptr = data_values.values().as_ref().as_ptr();
let mut iter = variable_shape_tensor_iter::<Float32Type>(&field, &array).unwrap();
let (_row, view) = iter.next().unwrap().unwrap();
assert_eq!(view.as_ptr(), ptr);
}
#[test]
fn fixed_shape_tensor_rejects_missing_extension() {
let values = Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0]);
let item_field = Arc::new(Field::new("item", DataType::Float32, false));
let fsl = FixedSizeListArray::new(item_field, 2, Arc::new(values), None);
let field = Field::new("tensor", fsl.data_type().clone(), false);
let err = fixed_shape_tensor_as_array_viewd::<Float32Type>(&field, &fsl).unwrap_err();
assert!(matches!(err, NdarrowError::Arrow(_)));
}
#[test]
fn fixed_shape_tensor_rejects_invalid_extension_metadata() {
let values = Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0]);
let item_field = Arc::new(Field::new("item", DataType::Float32, false));
let fsl = FixedSizeListArray::new(item_field, 2, Arc::new(values), None);
let mut metadata = HashMap::new();
metadata.insert(EXTENSION_TYPE_NAME_KEY.to_owned(), FixedShapeTensor::NAME.to_owned());
metadata.insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{bad json".to_owned());
let field = Field::new("tensor", fsl.data_type().clone(), false).with_metadata(metadata);
let err = fixed_shape_tensor_as_array_viewd::<Float32Type>(&field, &fsl).unwrap_err();
assert!(matches!(err, NdarrowError::InvalidMetadata { .. } | NdarrowError::Arrow(_)));
}
#[test]
fn fixed_shape_tensor_rejects_outer_nulls() {
let values = Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0]);
let item_field = Arc::new(Field::new("item", DataType::Float32, false));
let nulls = NullBuffer::from(vec![true, false]);
let fsl = FixedSizeListArray::new(item_field, 2, Arc::new(values), Some(nulls));
let mut metadata = HashMap::new();
metadata.insert(EXTENSION_TYPE_NAME_KEY.to_owned(), FixedShapeTensor::NAME.to_owned());
metadata.insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), r#"{"shape":[2]}"#.to_owned());
let field = Field::new("tensor", fsl.data_type().clone(), false).with_metadata(metadata);
let err = fixed_shape_tensor_as_array_viewd::<Float32Type>(&field, &fsl).unwrap_err();
assert!(matches!(err, NdarrowError::NullsPresent { .. }));
}
#[test]
fn fixed_shape_tensor_rejects_inner_type_mismatch() {
let values = Float64Array::from(vec![1.0_f64, 2.0, 3.0, 4.0]);
let item_field = Arc::new(Field::new("item", DataType::Float64, false));
let fsl = FixedSizeListArray::new(item_field, 2, Arc::new(values), None);
let fixed_ext = FixedShapeTensor::try_new(DataType::Float64, [2], None, None).unwrap();
let mut field = Field::new("tensor", fsl.data_type().clone(), false);
field.try_with_extension_type(fixed_ext).unwrap();
let err = fixed_shape_tensor_as_array_viewd::<Float32Type>(&field, &fsl).unwrap_err();
assert!(matches!(err, NdarrowError::InnerTypeMismatch { .. }));
}
#[test]
fn arrayd_to_fixed_shape_tensor_rejects_zero_dim() {
let scalar = ArrayD::from_elem(IxDyn(&[]), 1.0_f32);
let err = arrayd_to_fixed_shape_tensor("tensor", scalar).unwrap_err();
assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn variable_shape_tensor_iter_rejects_struct_nulls() {
let a = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![3.0_f32, 4.0]).unwrap();
let (field, array) = arrays_to_variable_shape_tensor("ragged", vec![a, b], None).unwrap();
let with_nulls = StructArray::new(
array.fields().clone(),
array.columns().to_vec(),
Some(NullBuffer::from(vec![true, false])),
);
let result = variable_shape_tensor_iter::<Float32Type>(&field, &with_nulls);
assert!(matches!(result, Err(NdarrowError::NullsPresent { .. })));
}
#[test]
fn variable_shape_tensor_batch_view_exposes_columnar_buffers() {
    let small = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
    let large = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![3.0_f32, 4.0, 5.0, 6.0]).unwrap();
    let (field, tensor) =
        arrays_to_variable_shape_tensor("ragged", vec![small, large], Some(vec![None, Some(2)]))
            .unwrap();
    // Attach an outer validity mask: row 0 valid, row 1 null.
    let masked = StructArray::new(
        tensor.fields().clone(),
        tensor.columns().to_vec(),
        Some(NullBuffer::from(vec![true, false])),
    );
    let batch = variable_shape_tensor_batch_view::<Float32Type>(&field, &masked).unwrap();

    // Batch-level shape and metadata.
    assert_eq!(batch.len(), 2);
    assert!(!batch.is_empty());
    assert_eq!(batch.dimensions(), 2);
    assert_eq!(batch.uniform_shape(), Some(&[None, Some(2)][..]));

    // Raw columnar buffers are exposed as contiguous slices.
    assert_eq!(batch.data_offsets(), &[0, 2, 6]);
    assert_eq!(batch.shape_values(), &[1, 2, 2, 2]);
    assert_eq!(batch.values(), &[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]);

    // The outer null buffer is carried through onto the view.
    let nulls = batch.nulls().expect("outer null buffer");
    assert!(nulls.is_valid(0));
    assert!(nulls.is_null(1));

    // Per-row access still works, even for the null-masked row.
    let row = batch.row(1).unwrap();
    assert_eq!(row.row(), 1);
    assert_eq!(row.shape(), &[2, 2]);
    assert_eq!(row.values(), &[3.0_f32, 4.0, 5.0, 6.0]);
    let view = row.as_array_viewd().unwrap();
    assert_abs_diff_eq!(view[[1, 1]], 6.0_f32);
}
#[test]
fn variable_shape_tensor_iter_masked_preserves_outer_nulls() {
    let first = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
    let second = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![3.0_f32, 4.0]).unwrap();
    let (field, tensor) =
        arrays_to_variable_shape_tensor("ragged", vec![first, second], None).unwrap();
    // Mask row 1 as null at the struct level.
    let masked = StructArray::new(
        tensor.fields().clone(),
        tensor.columns().to_vec(),
        Some(NullBuffer::from(vec![true, false])),
    );
    // The masked variant yields every row and hands back the validity mask separately.
    let (iter, maybe_nulls) =
        variable_shape_tensor_iter_masked::<Float32Type>(&field, &masked).unwrap();
    let rows = iter.collect::<Result<Vec<_>, _>>().unwrap();
    assert_eq!(rows.len(), 2);

    let nulls = maybe_nulls.expect("outer null buffer");
    assert!(nulls.is_valid(0));
    assert!(nulls.is_null(1));

    // Row payloads are still materialized regardless of validity.
    assert_eq!(rows[0].0, 0);
    assert_eq!(rows[0].1.shape(), &[1, 2]);
    assert_abs_diff_eq!(rows[0].1[[0, 1]], 2.0_f32);
    assert_eq!(rows[1].0, 1);
}
#[test]
fn variable_shape_tensor_iter_rejects_negative_shape_dimension() {
    // Hand-assemble a variable-shape tensor whose shape column contains -1.
    let values: ArrayRef = Arc::new(ListArray::new(
        Arc::new(Field::new_list_field(DataType::Float32, false)),
        OffsetBuffer::new(ScalarBuffer::from(vec![0_i32, 3_i32])),
        Arc::new(Float32Array::from(vec![1.0_f32, 2.0, 3.0])),
        None,
    ));
    let shapes: ArrayRef = Arc::new(FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Int32, false)),
        2,
        Arc::new(Int32Array::from(vec![-1_i32, 3_i32])),
        None,
    ));
    let struct_fields = vec![
        Field::new("data", values.data_type().clone(), false),
        Field::new("shape", shapes.data_type().clone(), false),
    ];
    let tensor = StructArray::new(struct_fields.into(), vec![values, shapes], None);
    let ext = VariableShapeTensor::try_new(DataType::Float32, 2, None, None, None).unwrap();
    let mut field = Field::new("ragged", tensor.data_type().clone(), false);
    field.try_with_extension_type(ext).unwrap();
    // Iterator construction succeeds; the bad dimension surfaces when the row is decoded.
    let mut iter = variable_shape_tensor_iter::<Float32Type>(&field, &tensor).unwrap();
    let err = iter.next().unwrap().unwrap_err();
    assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn variable_shape_tensor_iter_rejects_negative_uniform_shape_metadata() {
    let tensor_data = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
    let (_generated, array) =
        arrays_to_variable_shape_tensor("ragged", vec![tensor_data], None).unwrap();
    // Swap in a field whose uniform_shape metadata carries a negative dimension.
    let bad_ext = VariableShapeTensor::try_new(
        DataType::Float32,
        2,
        None,
        None,
        Some(vec![Some(-1), None]),
    )
    .unwrap();
    let mut field = Field::new("ragged", array.data_type().clone(), false);
    field.try_with_extension_type(bad_ext).unwrap();
    // The invalid metadata is reported when the first row is pulled.
    let mut iter = variable_shape_tensor_iter::<Float32Type>(&field, &array).unwrap();
    let err = iter.next().unwrap().unwrap_err();
    assert!(matches!(err, NdarrowError::InvalidMetadata { .. }));
}
#[test]
fn push_tensor_shape_rejects_uniform_mismatch() {
    // Second dimension is 4, but the uniform shape pins it to 3.
    let mut shape_buf = Vec::new();
    let err = push_tensor_shape(&[2, 4], 0, Some(&[Some(2), Some(3)]), &mut shape_buf)
        .unwrap_err();
    assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn push_tensor_shape_rejects_dimension_overflow() {
    // usize::MAX cannot be represented in the i32 shape column.
    let mut shape_buf = Vec::new();
    let err = push_tensor_shape(&[usize::MAX], 0, None, &mut shape_buf).unwrap_err();
    assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn append_row_values_handles_non_zero_offset() {
    // Slice away the first row so the view no longer starts at the buffer origin.
    let full =
        ArrayD::from_shape_vec(IxDyn(&[3, 2]), vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let tail = full.slice_move(ndarray::s![1.., ..]).into_dyn();
    let mut out = Vec::new();
    let appended = append_row_values(tail, 0, &mut out).unwrap();
    // Only the remaining 2x2 window should be packed.
    assert_eq!(appended, 4);
    assert_eq!(out, vec![3.0_f64, 4.0, 5.0, 6.0]);
}
#[test]
fn arrays_to_variable_shape_tensor_rejects_empty_input() {
    // Zero tensors leave the dimensionality undetermined, which is an error.
    let err = arrays_to_variable_shape_tensor::<f32>("ragged", vec![], None).unwrap_err();
    assert!(matches!(err, NdarrowError::InvalidMetadata { .. }));
}
#[test]
fn arrays_to_variable_shape_tensor_rejects_uniform_shape_rank_mismatch() {
    // A rank-2 tensor cannot be described by a rank-3 uniform shape.
    let tensor = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
    let uniform = Some(vec![Some(1), Some(2), Some(3)]);
    let err = arrays_to_variable_shape_tensor("ragged", vec![tensor], uniform).unwrap_err();
    assert!(matches!(err, NdarrowError::InvalidMetadata { .. }));
}
#[test]
fn arrays_to_variable_shape_tensor_rejects_rank_mismatch() {
    // Mixing a rank-2 and a rank-3 tensor in one column is rejected.
    let rank2 = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f64, 2.0]).unwrap();
    let rank3 = ArrayD::from_shape_vec(IxDyn(&[1, 1, 1]), vec![3.0_f64]).unwrap();
    let err = arrays_to_variable_shape_tensor("ragged", vec![rank2, rank3], None).unwrap_err();
    assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn variable_shape_tensor_iter_rejects_shape_product_mismatch() {
    // Three data values, but the declared shape [1, 2] only accounts for two.
    let values: ArrayRef = Arc::new(ListArray::new(
        Arc::new(Field::new_list_field(DataType::Float32, false)),
        OffsetBuffer::new(ScalarBuffer::from(vec![0_i32, 3_i32])),
        Arc::new(Float32Array::from(vec![1.0_f32, 2.0, 3.0])),
        None,
    ));
    let shapes: ArrayRef = Arc::new(FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Int32, false)),
        2,
        Arc::new(Int32Array::from(vec![1_i32, 2_i32])),
        None,
    ));
    let struct_fields = vec![
        Field::new("data", values.data_type().clone(), false),
        Field::new("shape", shapes.data_type().clone(), false),
    ];
    let tensor = StructArray::new(struct_fields.into(), vec![values, shapes], None);
    let ext = VariableShapeTensor::try_new(DataType::Float32, 2, None, None, None).unwrap();
    let mut field = Field::new("ragged", tensor.data_type().clone(), false);
    field.try_with_extension_type(ext).unwrap();
    // The length/shape disagreement is caught during row decoding, not construction.
    let mut iter = variable_shape_tensor_iter::<Float32Type>(&field, &tensor).unwrap();
    let err = iter.next().unwrap().unwrap_err();
    assert!(matches!(err, NdarrowError::ShapeMismatch { .. }));
}
#[test]
fn variable_shape_tensor_iter_rejects_wrong_extension_name() {
    let tensor = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0_f32, 2.0]).unwrap();
    let (_generated, array) =
        arrays_to_variable_shape_tensor("ragged", vec![tensor], None).unwrap();
    // Overwrite the extension name with something the reader does not recognize.
    let mut bogus_metadata = HashMap::new();
    bogus_metadata.insert(EXTENSION_TYPE_NAME_KEY.to_owned(), "ndarrow.not_tensor".to_owned());
    bogus_metadata.insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), "{}".to_owned());
    let bad_field =
        Field::new("ragged", array.data_type().clone(), false).with_metadata(bogus_metadata);
    let result = variable_shape_tensor_iter::<Float32Type>(&bad_field, &array);
    assert!(matches!(result, Err(NdarrowError::Arrow(_))));
}
}