use serde_core::de::{self, MapAccess, Visitor};
use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use crate::{ArrowError, DataType, Field, extension::ExtensionType};
/// The `arrow.variable_shape_tensor` extension type.
///
/// Pairs the element type and dimension count of the tensors with the
/// optional, already validated [`VariableShapeTensorMetadata`].
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensor {
/// Data type of the tensor elements, i.e. the item type of the `data` list
/// in the storage struct.
value_type: DataType,
/// Number of tensor dimensions, i.e. the size of the fixed-size `shape`
/// list in the storage struct.
dimensions: usize,
/// Optional metadata; every present entry has exactly `dimensions` items.
metadata: VariableShapeTensorMetadata,
}
impl VariableShapeTensor {
    /// Creates a variable shape tensor extension type for tensors with
    /// elements of `value_type` and `dimensions` dimensions.
    ///
    /// The three optional arguments are forwarded unchanged to
    /// [`VariableShapeTensorMetadata::try_new`], which validates them against
    /// `dimensions`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the metadata validation reports.
    pub fn try_new(
        value_type: DataType,
        dimensions: usize,
        dimension_names: Option<Vec<String>>,
        permutations: Option<Vec<usize>>,
        uniform_shapes: Option<Vec<Option<i32>>>,
    ) -> Result<Self, ArrowError> {
        // Validate first: no tensor type is built around rejected metadata.
        let metadata = VariableShapeTensorMetadata::try_new(
            dimensions,
            dimension_names,
            permutations,
            uniform_shapes,
        )?;

        Ok(Self {
            value_type,
            dimensions,
            metadata,
        })
    }

    /// Returns the data type of the tensor elements.
    pub fn value_type(&self) -> &DataType {
        let Self { value_type, .. } = self;
        value_type
    }

    /// Returns the number of dimensions of the tensors.
    pub fn dimensions(&self) -> usize {
        let Self { dimensions, .. } = self;
        *dimensions
    }

    /// Returns the dimension names stored in the metadata, if any.
    pub fn dimension_names(&self) -> Option<&[String]> {
        let Self { metadata, .. } = self;
        metadata.dimension_names()
    }

    /// Returns the dimension permutation stored in the metadata, if any.
    pub fn permutations(&self) -> Option<&[usize]> {
        let Self { metadata, .. } = self;
        metadata.permutations()
    }

    /// Returns the uniform shape entries stored in the metadata, if any.
    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
        let Self { metadata, .. } = self;
        metadata.uniform_shapes()
    }
}
/// Metadata of the [`VariableShapeTensor`] extension type.
///
/// All entries are optional. Values built through
/// [`VariableShapeTensorMetadata::try_new`] have exactly one item per tensor
/// dimension in every present entry.
#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensorMetadata {
/// Names of the tensor dimensions; serialized under the key `dim_names`.
dim_names: Option<Vec<String>>,
/// A permutation of the dimension indices `0..N`; serialized under the key
/// `permutations`.
permutations: Option<Vec<usize>>,
/// One optional size per dimension; serialized under the key
/// `uniform_shape`.
// NOTE(review): a `None` item presumably marks a dimension whose size varies
// between tensors - confirm against the Arrow canonical extension type spec.
uniform_shape: Option<Vec<Option<i32>>>,
}
/// Serializes the metadata as a three-field struct named
/// `VariableShapeTensorMetadata`.
///
/// The fields are always written in the order `dim_names`, `permutations`,
/// `uniform_shape`; unset entries are handed to the serializer as `None`.
impl Serialize for VariableShapeTensorMetadata {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde_core::ser::SerializeStruct;

        let Self {
            dim_names,
            permutations,
            uniform_shape,
        } = self;

        let mut out = serializer.serialize_struct("VariableShapeTensorMetadata", 3)?;
        SerializeStruct::serialize_field(&mut out, "dim_names", dim_names)?;
        SerializeStruct::serialize_field(&mut out, "permutations", permutations)?;
        SerializeStruct::serialize_field(&mut out, "uniform_shape", uniform_shape)?;
        SerializeStruct::end(out)
    }
}
/// Identifies one of the keys of the serialized
/// [`VariableShapeTensorMetadata`] while deserializing it from a map.
#[derive(Debug)]
enum MetadataField {
/// The `dim_names` key.
DimNames,
/// The `permutations` key.
Permutations,
/// The `uniform_shape` key.
UniformShape,
}
/// Visitor that turns a metadata key string into a [`MetadataField`].
struct MetadataFieldVisitor;

impl<'de> Visitor<'de> for MetadataFieldVisitor {
    type Value = MetadataField;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`")
    }

    /// Maps a key to its field; any other key is an `unknown_field` error.
    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
    where
        E: de::Error,
    {
        // Reported back to the user when an unexpected key shows up.
        const KNOWN_FIELDS: &[&str] = &["dim_names", "permutations", "uniform_shape"];

        let field = match value {
            "dim_names" => MetadataField::DimNames,
            "permutations" => MetadataField::Permutations,
            "uniform_shape" => MetadataField::UniformShape,
            unknown => return Err(E::unknown_field(unknown, KNOWN_FIELDS)),
        };
        Ok(field)
    }
}
/// Deserializes a metadata key by asking the deserializer for an identifier
/// and resolving it with [`MetadataFieldVisitor`].
impl<'de> Deserialize<'de> for MetadataField {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = MetadataFieldVisitor;
        D::deserialize_identifier(deserializer, visitor)
    }
}
struct VariableShapeTensorMetadataVisitor;
impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor {
type Value = VariableShapeTensorMetadata;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct VariableShapeTensorMetadata")
}
fn visit_seq<V>(self, mut seq: V) -> Result<VariableShapeTensorMetadata, V::Error>
where
V: de::SeqAccess<'de>,
{
let dim_names = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let permutations = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let uniform_shape = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
Ok(VariableShapeTensorMetadata {
dim_names,
permutations,
uniform_shape,
})
}
fn visit_map<V>(self, mut map: V) -> Result<VariableShapeTensorMetadata, V::Error>
where
V: MapAccess<'de>,
{
let mut dim_names = None;
let mut permutations = None;
let mut uniform_shape = None;
while let Some(key) = map.next_key()? {
match key {
MetadataField::DimNames => {
if dim_names.is_some() {
return Err(de::Error::duplicate_field("dim_names"));
}
dim_names = Some(map.next_value()?);
}
MetadataField::Permutations => {
if permutations.is_some() {
return Err(de::Error::duplicate_field("permutations"));
}
permutations = Some(map.next_value()?);
}
MetadataField::UniformShape => {
if uniform_shape.is_some() {
return Err(de::Error::duplicate_field("uniform_shape"));
}
uniform_shape = Some(map.next_value()?);
}
}
}
Ok(VariableShapeTensorMetadata {
dim_names,
permutations,
uniform_shape,
})
}
}
/// Deserializes the metadata as a struct named `VariableShapeTensorMetadata`
/// using [`VariableShapeTensorMetadataVisitor`].
impl<'de> Deserialize<'de> for VariableShapeTensorMetadata {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names announced to the deserializer, in serialization order.
        const FIELDS: &[&str] = &["dim_names", "permutations", "uniform_shape"];

        D::deserialize_struct(
            deserializer,
            "VariableShapeTensorMetadata",
            FIELDS,
            VariableShapeTensorMetadataVisitor,
        )
    }
}
impl VariableShapeTensorMetadata {
    /// Creates metadata for tensors with `dimensions` dimensions.
    ///
    /// Each argument that is `Some` must hold exactly `dimensions` items, and
    /// `permutations` must additionally be a permutation of `0..dimensions`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] for the first violated
    /// rule, checked in this order: dimension names length, permutations
    /// length, permutations content, uniform shapes length.
    pub fn try_new(
        dimensions: usize,
        dimension_names: Option<Vec<String>>,
        permutations: Option<Vec<usize>>,
        uniform_shapes: Option<Vec<Option<i32>>>,
    ) -> Result<Self, ArrowError> {
        /// Fails with the shared "size mismatch" message when `found` differs
        /// from `expected`.
        fn ensure_len(what: &str, expected: usize, found: usize) -> Result<(), ArrowError> {
            if found == expected {
                return Ok(());
            }
            Err(ArrowError::InvalidArgumentError(format!(
                "VariableShapeTensor {what} size mismatch, expected {expected}, found {found}"
            )))
        }

        if let Some(names) = dimension_names.as_deref() {
            ensure_len("dimension names", dimensions, names.len())?;
        }

        if let Some(order) = permutations.as_deref() {
            ensure_len("permutations", dimensions, order.len())?;

            // After sorting, a valid permutation holds each index at its own
            // position.
            let mut sorted = order.to_vec();
            sorted.sort_unstable();
            let is_permutation = sorted
                .iter()
                .enumerate()
                .all(|(index, &value)| index == value);
            if !is_permutation {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}"
                )));
            }
        }

        if let Some(shapes) = uniform_shapes.as_deref() {
            ensure_len("uniform shapes", dimensions, shapes.len())?;
        }

        Ok(Self {
            dim_names: dimension_names,
            permutations,
            uniform_shape: uniform_shapes,
        })
    }

    /// Returns the names of the tensor dimensions, if set.
    pub fn dimension_names(&self) -> Option<&[String]> {
        self.dim_names.as_deref()
    }

    /// Returns the permutation of the dimension indices, if set.
    pub fn permutations(&self) -> Option<&[usize]> {
        self.permutations.as_deref()
    }

    /// Returns the per-dimension uniform shape entries, if set.
    pub fn uniform_shapes(&self) -> Option<&[Option<i32>]> {
        self.uniform_shape.as_deref()
    }
}
/// Extension type implementation for [`VariableShapeTensor`].
///
/// The storage type is a two-field struct: a `data` list holding the tensor
/// elements followed by a fixed-size `shape` list whose size is the number of
/// dimensions.
impl ExtensionType for VariableShapeTensor {
    const NAME: &'static str = "arrow.variable_shape_tensor";

    type Metadata = VariableShapeTensorMetadata;

    fn metadata(&self) -> &Self::Metadata {
        &self.metadata
    }

    /// Serializes the metadata to a JSON string; always returns `Some`.
    fn serialize_metadata(&self) -> Option<String> {
        Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
    }

    /// Parses the metadata from its JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when `metadata` is `None`
    /// or is not valid JSON for [`VariableShapeTensorMetadata`].
    fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
        match metadata {
            None => Err(ArrowError::InvalidArgumentError(
                "VariableShapeTensor extension types requires metadata".to_owned(),
            )),
            Some(value) => serde_json::from_str(value).map_err(|e| {
                ArrowError::InvalidArgumentError(format!(
                    "VariableShapeTensor metadata deserialization failed: {e}"
                ))
            }),
        }
    }

    /// Checks that `data_type` equals the storage type implied by this
    /// extension type: `Struct { data: List<value_type>, shape:
    /// FixedSizeList<Int32>[dimensions] }` with non-nullable fields.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when the types differ, or
    /// when the number of dimensions does not fit the `i32` size of a
    /// fixed-size list (no storage type can match in that case).
    fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> {
        // A fixed-size list stores its size as `i32`; report an oversized
        // dimension count as an error instead of panicking.
        let shape_len = i32::try_from(self.dimensions()).map_err(|_| {
            ArrowError::InvalidArgumentError(format!(
                "VariableShapeTensor dimensions overflow, expected at most {}, found {}",
                i32::MAX,
                self.dimensions()
            ))
        })?;

        let expected = DataType::Struct(
            [
                Field::new_list(
                    "data",
                    Field::new_list_field(self.value_type.clone(), false),
                    false,
                ),
                Field::new(
                    "shape",
                    DataType::new_fixed_size_list(DataType::Int32, shape_len, false),
                    false,
                ),
            ]
            .into_iter()
            .collect(),
        );

        data_type
            .equals_datatype(&expected)
            .then_some(())
            .ok_or_else(|| {
                ArrowError::InvalidArgumentError(format!(
                    "VariableShapeTensor data type mismatch, expected {expected}, found {data_type}"
                ))
            })
    }

    /// Builds the extension type from a storage `data_type` and deserialized
    /// `metadata`.
    ///
    /// The number of dimensions is the size of the `shape` list, the value
    /// type is the item type of the `data` list, and `metadata` is
    /// re-validated against that number of dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when `data_type` is not a
    /// struct with exactly the fields `data` and `shape` (in that order),
    /// when `shape` is not a fixed-size list or has a negative size, when the
    /// metadata is inconsistent with the number of dimensions, or when `data`
    /// is not a list.
    fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
        let fields = match data_type {
            DataType::Struct(fields)
                if fields.len() == 2
                    && matches!(fields.find("data"), Some((0, _)))
                    && matches!(fields.find("shape"), Some((1, _))) =>
            {
                fields
            }
            data_type => {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "VariableShapeTensor data type mismatch, expected Struct with 2 fields (data and shape), found {data_type}"
                )));
            }
        };

        let list_size = match fields[1].data_type() {
            DataType::FixedSizeList(_, list_size) => *list_size,
            data_type => {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "VariableShapeTensor data type mismatch, expected FixedSizeList for shape field, found {data_type}"
                )));
            }
        };

        // The size is a plain `i32` and may be negative in a hand-built or
        // externally supplied data type; reject it instead of panicking.
        let dimensions = usize::try_from(list_size).map_err(|_| {
            ArrowError::InvalidArgumentError(format!(
                "VariableShapeTensor data type mismatch, expected non-negative FixedSizeList size for shape field, found {list_size}"
            ))
        })?;

        // Deserialized metadata has not been checked against the dimension
        // count yet, so run it through the validating constructor.
        let metadata = VariableShapeTensorMetadata::try_new(
            dimensions,
            metadata.dim_names,
            metadata.permutations,
            metadata.uniform_shape,
        )?;

        match fields[0].data_type() {
            DataType::List(field) => Ok(Self {
                value_type: field.data_type().clone(),
                dimensions,
                metadata,
            }),
            data_type => Err(ArrowError::InvalidArgumentError(format!(
                "VariableShapeTensor data type mismatch, expected List for data field, found {data_type}"
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    /// Builds the storage field used throughout these tests: a non-nullable
    /// struct with a `data` list of `Float32` and a `shape` fixed-size list of
    /// three `Int32` values.
    fn storage_field() -> Field {
        let data = Field::new_list(
            "data",
            Field::new_list_field(DataType::Float32, false),
            false,
        );
        let shape = Field::new_fixed_size_list(
            "shape",
            Field::new("", DataType::Int32, false),
            3,
            false,
        );
        Field::new_struct("", vec![data, shape], false)
    }

    /// Returns [`storage_field`] carrying exactly the given field metadata
    /// entries.
    fn storage_field_with_metadata<const N: usize>(entries: [(&str, &str); N]) -> Field {
        storage_field().with_metadata(
            entries
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value.to_owned()))
                .collect(),
        )
    }

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let tensor = VariableShapeTensor::try_new(
            DataType::Float32,
            3,
            Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]),
            Some(vec![2, 0, 1]),
            Some(vec![Some(400), None, Some(3)]),
        )?;

        let mut field = storage_field();
        field.try_with_extension_type(tensor.clone())?;

        assert_eq!(field.try_extension_type::<VariableShapeTensor>()?, tensor);
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::VariableShapeTensor(tensor)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Extension type name missing")]
    fn missing_name() {
        // Only the metadata key is set; the extension name key is absent.
        let field = storage_field_with_metadata([(EXTENSION_TYPE_METADATA_KEY, "{}")]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor data type mismatch, expected Struct")]
    fn invalid_type() {
        // The extension type expects `Int32` values, the field stores `Float32`.
        let tensor = VariableShapeTensor::try_new(DataType::Int32, 3, None, None, None).unwrap();
        let field = storage_field();
        field.with_extension_type(tensor);
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor extension types requires metadata")]
    fn missing_metadata() {
        // Only the extension name key is set; the metadata key is absent.
        let field =
            storage_field_with_metadata([(EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME)]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(expected = "VariableShapeTensor metadata deserialization failed: invalid type:")]
    fn invalid_metadata() {
        let field = storage_field_with_metadata([
            (EXTENSION_TYPE_NAME_KEY, VariableShapeTensor::NAME),
            (
                EXTENSION_TYPE_METADATA_KEY,
                r#"{ "dim_names": [1, null, 3, 4] }"#,
            ),
        ]);
        field.extension_type::<VariableShapeTensor>();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor dimension names size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_dimension_names() {
        let names = vec!["a".to_owned(), "b".to_owned()];
        VariableShapeTensor::try_new(DataType::Float32, 3, Some(names), None, None).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_permutations_len() {
        let permutations = vec![1, 0];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3"
    )]
    fn invalid_metadata_permutations_values() {
        let permutations = vec![4, 3, 2];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, Some(permutations), None)
            .unwrap();
    }

    #[test]
    #[should_panic(
        expected = "VariableShapeTensor uniform shapes size mismatch, expected 3, found 2"
    )]
    fn invalid_metadata_uniform_shapes() {
        let uniform_shapes = vec![None, Some(1)];
        VariableShapeTensor::try_new(DataType::Float32, 3, None, None, Some(uniform_shapes))
            .unwrap();
    }
}