use std::collections::HashMap;
use marrow::datatypes::{DataType, Field};
use crate::internal::{
error::{fail, Error, Result},
schema::{transmute_field, PrettyField},
};
use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr};
#[derive(Clone, Debug, PartialEq)]
pub struct FixedShapeTensorField {
name: String,
nullable: bool,
element: Field,
shape: Vec<usize>,
dim_names: Option<Vec<String>>,
permutation: Option<Vec<usize>>,
}
impl FixedShapeTensorField {
pub fn new(name: &str, element: impl serde::ser::Serialize, shape: Vec<usize>) -> Result<Self> {
let element = transmute_field(element)?;
if element.name != "element" {
fail!("The element field of FixedShapeTensorField must be named \"element\"");
}
Ok(Self {
name: name.to_owned(),
shape,
element,
nullable: false,
dim_names: None,
permutation: None,
})
}
pub fn nullable(mut self, value: bool) -> Self {
self.nullable = value;
self
}
pub fn permutation(mut self, value: Vec<usize>) -> Result<Self> {
check_permutation(self.shape.len(), &value)?;
self.permutation = Some(value);
Ok(self)
}
pub fn dim_names(mut self, value: Vec<String>) -> Result<Self> {
check_dim_names(self.shape.len(), &value)?;
self.dim_names = Some(value);
Ok(self)
}
}
impl FixedShapeTensorField {
fn get_ext_metadata(&self) -> Result<String> {
use std::fmt::Write;
let mut ext_metadata = String::new();
write!(&mut ext_metadata, "{{")?;
write!(&mut ext_metadata, "\"shape\":")?;
write_list(&mut ext_metadata, self.shape.iter())?;
if let Some(permutation) = self.permutation.as_ref() {
write!(&mut ext_metadata, ",\"permutation\":")?;
write_list(&mut ext_metadata, permutation.iter())?;
}
if let Some(dim_names) = self.dim_names.as_ref() {
write!(&mut ext_metadata, ",\"dim_names\":")?;
write_list(&mut ext_metadata, dim_names.iter().map(DebugRepr))?;
}
write!(&mut ext_metadata, "}}")?;
Ok(ext_metadata)
}
}
impl TryFrom<&FixedShapeTensorField> for Field {
type Error = Error;
fn try_from(value: &FixedShapeTensorField) -> Result<Self> {
let mut n = 1;
for s in &value.shape {
n *= *s;
}
let mut metadata = HashMap::new();
metadata.insert(
"ARROW:extension:name".into(),
"arrow.fixed_shape_tensor".into(),
);
metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?);
Ok(Field {
name: value.name.to_owned(),
nullable: value.nullable,
data_type: DataType::FixedSizeList(Box::new(value.element.clone()), n.try_into()?),
metadata,
})
}
}
impl serde::ser::Serialize for FixedShapeTensorField {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::Error;
let field = Field::try_from(self).map_err(S::Error::custom)?;
PrettyField(&field).serialize(serializer)
}
}
#[test]
fn fixed_shape_tensor_field_repr() -> crate::internal::error::PanicOnError<()> {
use serde_json::json;
let field = FixedShapeTensorField::new(
"hello",
json!({"name": "element", "data_type": "F32"}),
vec![2, 3],
)?;
let field = Field::try_from(&field)?;
let actual = serde_json::to_value(PrettyField(&field))?;
let expected = json!({
"name": "hello",
"data_type": "FixedSizeList(6)",
"children": [{
"name": "element",
"data_type": "F32",
}],
"metadata": {
"ARROW:extension:metadata": "{\"shape\":[2,3]}",
"ARROW:extension:name": "arrow.fixed_shape_tensor",
},
});
assert_eq!(actual, expected);
Ok(())
}