use serde::{Deserialize, Serialize};
use crate::{extension::ExtensionType, ArrowError, DataType};
/// The extension type for fixed shape tensors (`arrow.fixed_shape_tensor`).
///
/// Pairs the data type of the individual tensor elements with the
/// [`FixedShapeTensorMetadata`] (shape, optional dimension names and optional
/// permutation) that is serialized as the extension metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensor {
/// The data type of individual tensor elements.
value_type: DataType,
/// The metadata (shape, dimension names, permutation) of this extension type.
metadata: FixedShapeTensorMetadata,
}
impl FixedShapeTensor {
    /// Builds a fixed shape tensor extension type whose elements are of
    /// `value_type` and whose layout is described by `shape`, with optional
    /// `dimension_names` and `permutations`.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by [`FixedShapeTensorMetadata::try_new`]
    /// for the given shape, dimension names and permutations.
    pub fn try_new(
        value_type: DataType,
        shape: impl IntoIterator<Item = usize>,
        dimension_names: Option<Vec<String>>,
        permutations: Option<Vec<usize>>,
    ) -> Result<Self, ArrowError> {
        let metadata = FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations)?;
        Ok(Self {
            value_type,
            metadata,
        })
    }

    /// Returns the data type of the individual tensor elements.
    pub fn value_type(&self) -> &DataType {
        let Self { value_type, .. } = self;
        value_type
    }

    /// Returns the list size reported by the metadata (the product of the
    /// shape entries).
    pub fn list_size(&self) -> usize {
        FixedShapeTensorMetadata::list_size(&self.metadata)
    }

    /// Returns the number of dimensions, i.e. the number of shape entries.
    pub fn dimensions(&self) -> usize {
        FixedShapeTensorMetadata::dimensions(&self.metadata)
    }

    /// Returns the dimension names, if any were set.
    pub fn dimension_names(&self) -> Option<&[String]> {
        FixedShapeTensorMetadata::dimension_names(&self.metadata)
    }

    /// Returns the permutations, if any were set.
    pub fn permutations(&self) -> Option<&[usize]> {
        FixedShapeTensorMetadata::permutations(&self.metadata)
    }
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct FixedShapeTensorMetadata {
shape: Vec<usize>,
dim_names: Option<Vec<String>>,
permutations: Option<Vec<usize>>,
}
impl FixedShapeTensorMetadata {
    /// Creates the metadata for a fixed shape tensor, validating the optional
    /// parts against the number of dimensions in `shape`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when:
    /// - `dimension_names` is set and its length differs from the number of
    ///   dimensions,
    /// - `permutations` is set and its length differs from the number of
    ///   dimensions,
    /// - `permutations` is set and is not a permutation of `[0, 1, .., N-1]`,
    ///   where `N` is the number of dimensions.
    pub fn try_new(
        shape: impl IntoIterator<Item = usize>,
        dimension_names: Option<Vec<String>>,
        permutations: Option<Vec<usize>>,
    ) -> Result<Self, ArrowError> {
        let shape = shape.into_iter().collect::<Vec<_>>();
        let dimensions = shape.len();

        if let Some(names) = &dimension_names {
            if names.len() != dimensions {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}",
                    names.len()
                )));
            }
        }

        if let Some(indices) = &permutations {
            if indices.len() != dimensions {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}",
                    indices.len()
                )));
            }

            // The length already equals N, so the indices form a permutation
            // exactly when each one is in range and none occurs twice.
            let mut seen = vec![false; dimensions];
            let is_permutation = indices.iter().all(|&index| match seen.get_mut(index) {
                Some(slot) => {
                    let first_occurrence = !*slot;
                    *slot = true;
                    first_occurrence
                }
                None => false,
            });
            if !is_permutation {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
                )));
            }
        }

        Ok(Self {
            shape,
            dim_names: dimension_names,
            permutations,
        })
    }

    /// Returns the product of all shape entries, i.e. the number of elements
    /// of a single tensor. An empty shape yields `1`.
    ///
    /// The multiplication saturates at `usize::MAX` rather than overflowing,
    /// so an oversized shape can neither panic nor wrap around to a small,
    /// valid-looking size.
    pub fn list_size(&self) -> usize {
        self.shape
            .iter()
            .fold(1usize, |size, &dimension| size.saturating_mul(dimension))
    }

    /// Returns the number of dimensions, i.e. the number of shape entries.
    pub fn dimensions(&self) -> usize {
        self.shape.len()
    }

    /// Returns the dimension names, if any were set.
    pub fn dimension_names(&self) -> Option<&[String]> {
        self.dim_names.as_deref()
    }

    /// Returns the permutations, if any were set.
    pub fn permutations(&self) -> Option<&[usize]> {
        self.permutations.as_deref()
    }
}
impl ExtensionType for FixedShapeTensor {
    /// The name of this extension type.
    const NAME: &'static str = "arrow.fixed_shape_tensor";

    type Metadata = FixedShapeTensorMetadata;

    /// Returns the metadata of this extension type.
    fn metadata(&self) -> &Self::Metadata {
        &self.metadata
    }

    /// Serializes the metadata to a JSON string. Always returns `Some`.
    fn serialize_metadata(&self) -> Option<String> {
        Some(serde_json::to_string(&self.metadata).expect("metadata serialization"))
    }

    /// Parses the JSON-encoded metadata.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when `metadata` is `None`
    /// or when the JSON cannot be deserialized into the metadata type.
    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
        let value = metadata.ok_or_else(|| {
            ArrowError::InvalidArgumentError(
                "FixedShapeTensor extension types requires metadata".to_owned(),
            )
        })?;
        serde_json::from_str(value).map_err(|e| {
            ArrowError::InvalidArgumentError(format!(
                "FixedShapeTensor metadata deserialization failed: {e}"
            ))
        })
    }

    /// Checks that `data_type` equals a fixed size list of this type's value
    /// type whose size is the product of the shape, built with
    /// `nullable = false`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when the data type does
    /// not match, or when the product of the shape does not fit in an `i32`.
    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
        let list_size = fixed_shape_tensor_list_size(&self.metadata)?;
        let expected = DataType::new_fixed_size_list(self.value_type.clone(), list_size, false);
        if data_type.equals_datatype(&expected) {
            Ok(())
        } else {
            Err(ArrowError::InvalidArgumentError(format!(
                "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}"
            )))
        }
    }

    /// Builds the extension type from a storage data type and its metadata.
    ///
    /// The metadata is validated again through
    /// [`FixedShapeTensorMetadata::try_new`], because deserialization alone
    /// does not check dimension names or permutations against the shape.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when `data_type` is not a
    /// `FixedSizeList` with a non-nullable field, when the metadata is
    /// invalid, when the product of the shape does not fit in an `i32`, or
    /// when that product differs from the list size of `data_type`.
    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
        match data_type {
            DataType::FixedSizeList(field, list_size) if !field.is_nullable() => {
                let metadata = FixedShapeTensorMetadata::try_new(
                    metadata.shape,
                    metadata.dim_names,
                    metadata.permutations,
                )?;
                // The shape may come from untrusted, deserialized metadata, so
                // an oversized product is reported as an error, not a panic.
                let expected_size = fixed_shape_tensor_list_size(&metadata)?;
                if *list_size != expected_size {
                    Err(ArrowError::InvalidArgumentError(format!(
                        "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)"
                    )))
                } else {
                    Ok(Self {
                        value_type: field.data_type().clone(),
                        metadata,
                    })
                }
            }
            data_type => Err(ArrowError::InvalidArgumentError(format!(
                "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}"
            ))),
        }
    }
}

/// Returns the product of the shape of `metadata` as the `i32` size of the
/// `FixedSizeList` storage type, or an error when it does not fit.
fn fixed_shape_tensor_list_size(metadata: &FixedShapeTensorMetadata) -> Result<i32, ArrowError> {
    // Saturating multiplication keeps an overflowing product from panicking
    // or wrapping to a small value; a zero entry still yields zero.
    let elements = metadata
        .shape
        .iter()
        .fold(1usize, |size, &dimension| size.saturating_mul(dimension));
    i32::try_from(elements).map_err(|_| {
        ArrowError::InvalidArgumentError(format!(
            "FixedShapeTensor list size overflow, the product of shape {:?} does not fit in an i32",
            metadata.shape
        ))
    })
}
// Unit tests for the `FixedShapeTensor` extension type: one round trip through
// a `Field`, followed by one test per rejected input that pins the panic /
// error message.
#[cfg(test)]
mod tests {
#[cfg(feature = "canonical_extension_types")]
use crate::extension::CanonicalExtensionType;
use crate::{
extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
Field,
};
use super::*;
// Attaches a valid tensor type to a matching fixed size list field and reads
// the same extension type back from the field.
#[test]
fn valid() -> Result<(), ArrowError> {
let fixed_shape_tensor = FixedShapeTensor::try_new(
DataType::Float32,
[100, 200, 500],
Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
Some(vec![2, 0, 1]),
)?;
let mut field = Field::new_fixed_size_list(
"",
Field::new("", DataType::Float32, false),
i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
false,
);
field.try_with_extension_type(fixed_shape_tensor.clone())?;
assert_eq!(
field.try_extension_type::<FixedShapeTensor>()?,
fixed_shape_tensor
);
// Only checked when the canonical extension type enum is compiled in.
#[cfg(feature = "canonical_extension_types")]
assert_eq!(
field.try_canonical_extension_type()?,
CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor)
);
Ok(())
}
// The field carries extension metadata but no extension name.
// NOTE(review): the JSON literal has a trailing comma; the expected panic
// suggests the missing name is reported before the metadata is parsed.
#[test]
#[should_panic(expected = "Field extension type name missing")]
fn missing_name() {
let field =
Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
.with_metadata(
[(
EXTENSION_TYPE_METADATA_KEY.to_owned(),
r#"{ "shape": [100, 200, 500], }"#.to_owned(),
)]
.into_iter()
.collect(),
);
field.extension_type::<FixedShapeTensor>();
}
// The tensor declares `Int32` elements while the field stores `Float32`
// elements, so attaching the extension type must fail.
#[test]
#[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")]
fn invalid_type() {
let fixed_shape_tensor =
FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap();
let field = Field::new_fixed_size_list(
"",
Field::new("", DataType::Float32, false),
i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
false,
);
field.with_extension_type(fixed_shape_tensor);
}
// The field carries the extension name but no extension metadata.
#[test]
#[should_panic(expected = "FixedShapeTensor extension types requires metadata")]
fn missing_metadata() {
let field =
Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false)
.with_metadata(
[(
EXTENSION_TYPE_NAME_KEY.to_owned(),
FixedShapeTensor::NAME.to_owned(),
)]
.into_iter()
.collect(),
);
field.extension_type::<FixedShapeTensor>();
}
// The extension metadata is valid JSON but lacks the required `shape` key.
#[test]
#[should_panic(
expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`"
)]
fn invalid_metadata() {
let fixed_shape_tensor =
FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap();
let field = Field::new_fixed_size_list(
"",
Field::new("", DataType::Float32, false),
i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"),
false,
)
.with_metadata(
[
(
EXTENSION_TYPE_NAME_KEY.to_owned(),
FixedShapeTensor::NAME.to_owned(),
),
(
EXTENSION_TYPE_METADATA_KEY.to_owned(),
r#"{ "not-shape": [] }"#.to_owned(),
),
]
.into_iter()
.collect(),
);
field.extension_type::<FixedShapeTensor>();
}
// Two dimension names for a three-dimensional shape.
#[test]
#[should_panic(
expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2"
)]
fn invalid_metadata_dimension_names() {
FixedShapeTensor::try_new(
DataType::Float32,
[100, 200, 500],
Some(vec!["a".to_owned(), "b".to_owned()]),
None,
)
.unwrap();
}
// Two permutation indices for a three-dimensional shape.
#[test]
#[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")]
fn invalid_metadata_permutations_len() {
FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0]))
.unwrap();
}
// Three permutation indices, but not a permutation of `[0, 1, 2]`.
#[test]
#[should_panic(
expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
)]
fn invalid_metadata_permutations_values() {
FixedShapeTensor::try_new(
DataType::Float32,
[100, 200, 500],
None,
Some(vec![4, 3, 2]),
)
.unwrap();
}
}