use std::fmt::Formatter;
use std::sync::Arc;
use std::{collections::HashMap, fmt::Display};
use indexmap::IndexMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::features::ColumnMappingMode;
use crate::utils::require;
use crate::{DeltaResult, Error};
pub type Schema = StructType;
pub type SchemaRef = Arc<StructType>;
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
#[serde(untagged)]
pub enum MetadataValue {
Number(i32),
String(String),
Boolean(bool),
Other(serde_json::Value),
}
impl From<String> for MetadataValue {
fn from(value: String) -> Self {
Self::String(value)
}
}
impl From<&String> for MetadataValue {
fn from(value: &String) -> Self {
Self::String(value.clone())
}
}
impl From<i32> for MetadataValue {
fn from(value: i32) -> Self {
Self::Number(value)
}
}
impl From<bool> for MetadataValue {
fn from(value: bool) -> Self {
Self::Boolean(value)
}
}
#[derive(Debug)]
pub enum ColumnMetadataKey {
ColumnMappingId,
ColumnMappingPhysicalName,
GenerationExpression,
IdentityStart,
IdentityStep,
IdentityHighWaterMark,
IdentityAllowExplicitInsert,
Invariants,
}
impl AsRef<str> for ColumnMetadataKey {
fn as_ref(&self) -> &str {
match self {
Self::ColumnMappingId => "delta.columnMapping.id",
Self::ColumnMappingPhysicalName => "delta.columnMapping.physicalName",
Self::GenerationExpression => "delta.generationExpression",
Self::IdentityAllowExplicitInsert => "delta.identity.allowExplicitInsert",
Self::IdentityHighWaterMark => "delta.identity.highWaterMark",
Self::IdentityStart => "delta.identity.start",
Self::IdentityStep => "delta.identity.step",
Self::Invariants => "delta.invariants",
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
pub struct StructField {
pub name: String,
#[serde(rename = "type")]
pub data_type: DataType,
pub nullable: bool,
pub metadata: HashMap<String, MetadataValue>,
}
impl StructField {
pub fn new(name: impl Into<String>, data_type: impl Into<DataType>, nullable: bool) -> Self {
Self {
name: name.into(),
data_type: data_type.into(),
nullable,
metadata: HashMap::default(),
}
}
pub fn with_metadata(
mut self,
metadata: impl IntoIterator<Item = (impl Into<String>, impl Into<MetadataValue>)>,
) -> Self {
self.metadata = metadata
.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.collect();
self
}
pub fn get_config_value(&self, key: &ColumnMetadataKey) -> Option<&MetadataValue> {
self.metadata.get(key.as_ref())
}
pub fn physical_name(&self, mapping_mode: ColumnMappingMode) -> DeltaResult<&str> {
let physical_name_key = ColumnMetadataKey::ColumnMappingPhysicalName.as_ref();
let name_mapped_name = self.metadata.get(physical_name_key);
match (mapping_mode, name_mapped_name) {
(ColumnMappingMode::None, _) => Ok(self.name.as_str()),
(ColumnMappingMode::Name, Some(MetadataValue::String(name))) => Ok(name),
(ColumnMappingMode::Name, invalid) => Err(Error::generic(format!(
"Missing or invalid {physical_name_key}: {invalid:?}"
))),
(ColumnMappingMode::Id, _) => {
Err(Error::generic("Don't support id column mapping yet"))
}
}
}
pub fn with_name(&self, new_name: impl Into<String>) -> Self {
StructField {
name: new_name.into(),
data_type: self.data_type().clone(),
nullable: self.nullable,
metadata: self.metadata.clone(),
}
}
#[inline]
pub fn name(&self) -> &String {
&self.name
}
#[inline]
pub fn is_nullable(&self) -> bool {
self.nullable
}
#[inline]
pub const fn data_type(&self) -> &DataType {
&self.data_type
}
#[inline]
pub const fn metadata(&self) -> &HashMap<String, MetadataValue> {
&self.metadata
}
}
#[derive(Debug, PartialEq, Clone, Eq)]
pub struct StructType {
pub type_name: String,
pub fields: IndexMap<String, StructField>,
}
impl StructType {
pub fn new(fields: Vec<StructField>) -> Self {
Self {
type_name: "struct".into(),
fields: fields.into_iter().map(|f| (f.name.clone(), f)).collect(),
}
}
pub fn project_as_struct(&self, names: &[impl AsRef<str>]) -> DeltaResult<StructType> {
let fields = names
.iter()
.map(|name| {
self.fields
.get(name.as_ref())
.cloned()
.ok_or_else(|| Error::missing_column(name.as_ref()))
})
.try_collect()?;
Ok(Self::new(fields))
}
pub fn project(&self, names: &[impl AsRef<str>]) -> DeltaResult<SchemaRef> {
let struct_type = self.project_as_struct(names)?;
Ok(Arc::new(struct_type))
}
pub fn field(&self, name: impl AsRef<str>) -> Option<&StructField> {
self.fields.get(name.as_ref())
}
pub fn index_of(&self, name: impl AsRef<str>) -> Option<usize> {
self.fields.get_index_of(name.as_ref())
}
pub fn fields(&self) -> impl Iterator<Item = &StructField> {
self.fields.values()
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct StructTypeSerDeHelper {
#[serde(rename = "type")]
type_name: String,
fields: Vec<StructField>,
}
impl Serialize for StructType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
StructTypeSerDeHelper {
type_name: self.type_name.clone(),
fields: self.fields.values().cloned().collect(),
}
.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for StructType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
Self: Sized,
{
let helper = StructTypeSerDeHelper::deserialize(deserializer)?;
Ok(Self {
type_name: helper.type_name,
fields: helper
.fields
.into_iter()
.map(|f| (f.name.clone(), f))
.collect(),
})
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ArrayType {
#[serde(rename = "type")]
pub type_name: String,
pub element_type: DataType,
pub contains_null: bool,
}
impl ArrayType {
pub fn new(element_type: DataType, contains_null: bool) -> Self {
Self {
type_name: "array".into(),
element_type,
contains_null,
}
}
#[inline]
pub const fn element_type(&self) -> &DataType {
&self.element_type
}
#[inline]
pub const fn contains_null(&self) -> bool {
self.contains_null
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
#[serde(rename_all = "camelCase")]
pub struct MapType {
#[serde(rename = "type")]
pub type_name: String,
pub key_type: DataType,
pub value_type: DataType,
#[serde(default = "default_true")]
pub value_contains_null: bool,
}
impl MapType {
pub fn new(key_type: DataType, value_type: DataType, value_contains_null: bool) -> Self {
Self {
type_name: "map".into(),
key_type,
value_type,
value_contains_null,
}
}
#[inline]
pub const fn key_type(&self) -> &DataType {
&self.key_type
}
#[inline]
pub const fn value_type(&self) -> &DataType {
&self.value_type
}
#[inline]
pub const fn value_contains_null(&self) -> bool {
self.value_contains_null
}
pub fn as_struct_schema(&self, key_name: String, val_name: String) -> Schema {
StructType::new(vec![
StructField::new(key_name, self.key_type.clone(), false),
StructField::new(val_name, self.value_type.clone(), self.value_contains_null),
])
}
}
fn default_true() -> bool {
true
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
#[serde(rename_all = "camelCase")]
pub enum PrimitiveType {
String,
Long,
Integer,
Short,
Byte,
Float,
Double,
Boolean,
Binary,
Date,
Timestamp,
#[serde(rename = "timestamp_ntz")]
TimestampNtz,
#[serde(
serialize_with = "serialize_decimal",
deserialize_with = "deserialize_decimal",
untagged
)]
Decimal(u8, u8),
}
fn serialize_decimal<S: serde::Serializer>(
precision: &u8,
scale: &u8,
serializer: S,
) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&format!("decimal({},{})", precision, scale))
}
fn deserialize_decimal<'de, D>(deserializer: D) -> Result<(u8, u8), D::Error>
where
D: serde::Deserializer<'de>,
{
let str_value = String::deserialize(deserializer)?;
require!(
str_value.starts_with("decimal(") && str_value.ends_with(')'),
serde::de::Error::custom(format!("Invalid decimal: {}", str_value))
);
let mut parts = str_value[8..str_value.len() - 1].split(',');
let precision = parts
.next()
.and_then(|part| part.trim().parse::<u8>().ok())
.ok_or_else(|| {
serde::de::Error::custom(format!("Invalid precision in decimal: {}", str_value))
})?;
let scale = parts
.next()
.and_then(|part| part.trim().parse::<u8>().ok())
.ok_or_else(|| {
serde::de::Error::custom(format!("Invalid scale in decimal: {}", str_value))
})?;
PrimitiveType::check_decimal(precision, scale).map_err(serde::de::Error::custom)?;
Ok((precision, scale))
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::String => write!(f, "string"),
PrimitiveType::Long => write!(f, "long"),
PrimitiveType::Integer => write!(f, "integer"),
PrimitiveType::Short => write!(f, "short"),
PrimitiveType::Byte => write!(f, "byte"),
PrimitiveType::Float => write!(f, "float"),
PrimitiveType::Double => write!(f, "double"),
PrimitiveType::Boolean => write!(f, "boolean"),
PrimitiveType::Binary => write!(f, "binary"),
PrimitiveType::Date => write!(f, "date"),
PrimitiveType::Timestamp => write!(f, "timestamp"),
PrimitiveType::TimestampNtz => write!(f, "timestamp_ntz"),
PrimitiveType::Decimal(precision, scale) => {
write!(f, "decimal({},{})", precision, scale)
}
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)]
#[serde(untagged, rename_all = "camelCase")]
pub enum DataType {
Primitive(PrimitiveType),
Array(Box<ArrayType>),
Struct(Box<StructType>),
Map(Box<MapType>),
}
impl From<MapType> for DataType {
fn from(map_type: MapType) -> Self {
DataType::Map(Box::new(map_type))
}
}
impl From<StructType> for DataType {
fn from(struct_type: StructType) -> Self {
DataType::Struct(Box::new(struct_type))
}
}
impl From<ArrayType> for DataType {
fn from(array_type: ArrayType) -> Self {
DataType::Array(Box::new(array_type))
}
}
impl From<SchemaRef> for DataType {
fn from(schema: SchemaRef) -> Self {
Arc::unwrap_or_clone(schema).into()
}
}
impl DataType {
pub const STRING: Self = DataType::Primitive(PrimitiveType::String);
pub const LONG: Self = DataType::Primitive(PrimitiveType::Long);
pub const INTEGER: Self = DataType::Primitive(PrimitiveType::Integer);
pub const SHORT: Self = DataType::Primitive(PrimitiveType::Short);
pub const BYTE: Self = DataType::Primitive(PrimitiveType::Byte);
pub const FLOAT: Self = DataType::Primitive(PrimitiveType::Float);
pub const DOUBLE: Self = DataType::Primitive(PrimitiveType::Double);
pub const BOOLEAN: Self = DataType::Primitive(PrimitiveType::Boolean);
pub const BINARY: Self = DataType::Primitive(PrimitiveType::Binary);
pub const DATE: Self = DataType::Primitive(PrimitiveType::Date);
pub const TIMESTAMP: Self = DataType::Primitive(PrimitiveType::Timestamp);
pub const TIMESTAMP_NTZ: Self = DataType::Primitive(PrimitiveType::TimestampNtz);
pub fn decimal(precision: u8, scale: u8) -> DeltaResult<Self> {
PrimitiveType::check_decimal(precision, scale)?;
Ok(DataType::Primitive(PrimitiveType::Decimal(
precision, scale,
)))
}
pub fn decimal_unchecked(precision: u8, scale: u8) -> Self {
Self::decimal(precision, scale).unwrap()
}
pub fn struct_type(fields: Vec<StructField>) -> Self {
DataType::Struct(Box::new(StructType::new(fields)))
}
pub fn array_type(elements: ArrayType) -> Self {
DataType::Array(Box::new(elements))
}
}
impl Display for DataType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DataType::Primitive(p) => write!(f, "{}", p),
DataType::Array(a) => write!(f, "array<{}>", a.element_type),
DataType::Struct(s) => {
write!(f, "struct<")?;
for (i, field) in s.fields().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}: {}", field.name, field.data_type)?;
}
write!(f, ">")
}
DataType::Map(m) => write!(f, "map<{}, {}>", m.key_type, m.value_type),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_serde_data_types() {
let data = r#"
{
"name": "a",
"type": "integer",
"nullable": false,
"metadata": {}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
assert!(matches!(
field.data_type,
DataType::Primitive(PrimitiveType::Integer)
));
let data = r#"
{
"name": "c",
"type": {
"type": "array",
"elementType": "integer",
"containsNull": false
},
"nullable": true,
"metadata": {}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
assert!(matches!(field.data_type, DataType::Array(_)));
let data = r#"
{
"name": "e",
"type": {
"type": "array",
"elementType": {
"type": "struct",
"fields": [
{
"name": "d",
"type": "integer",
"nullable": false,
"metadata": {}
}
]
},
"containsNull": true
},
"nullable": true,
"metadata": {}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
assert!(matches!(field.data_type, DataType::Array(_)));
match field.data_type {
DataType::Array(array) => assert!(matches!(array.element_type, DataType::Struct(_))),
_ => unreachable!(),
}
let data = r#"
{
"name": "f",
"type": {
"type": "map",
"keyType": "string",
"valueType": "string",
"valueContainsNull": true
},
"nullable": true,
"metadata": {}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
assert!(matches!(field.data_type, DataType::Map(_)));
}
#[test]
fn test_roundtrip_decimal() {
let data = r#"
{
"name": "a",
"type": "decimal(10, 2)",
"nullable": false,
"metadata": {}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
assert!(matches!(
field.data_type,
DataType::Primitive(PrimitiveType::Decimal(10, 2))
));
let json_str = serde_json::to_string(&field).unwrap();
assert_eq!(
json_str,
r#"{"name":"a","type":"decimal(10,2)","nullable":false,"metadata":{}}"#
);
}
#[test]
fn test_field_metadata() {
let data = r#"
{
"name": "e",
"type": {
"type": "array",
"elementType": {
"type": "struct",
"fields": [
{
"name": "d",
"type": "integer",
"nullable": false,
"metadata": {
"delta.columnMapping.id": 5,
"delta.columnMapping.physicalName": "col-a7f4159c-53be-4cb0-b81a-f7e5240cfc49"
}
}
]
},
"containsNull": true
},
"nullable": true,
"metadata": {
"delta.columnMapping.id": 4,
"delta.columnMapping.physicalName": "col-5f422f40-de70-45b2-88ab-1d5c90e94db1"
}
}
"#;
let field: StructField = serde_json::from_str(data).unwrap();
let col_id = field
.get_config_value(&ColumnMetadataKey::ColumnMappingId)
.unwrap();
assert!(matches!(col_id, MetadataValue::Number(num) if *num == 4));
let physical_name = field
.get_config_value(&ColumnMetadataKey::ColumnMappingPhysicalName)
.unwrap();
assert!(
matches!(physical_name, MetadataValue::String(name) if *name == "col-5f422f40-de70-45b2-88ab-1d5c90e94db1")
);
}
#[test]
fn test_read_schemas() {
let file = std::fs::File::open("./tests/serde/schema.json").unwrap();
let schema: Result<Schema, _> = serde_json::from_reader(file);
assert!(schema.is_ok());
let file = std::fs::File::open("./tests/serde/checkpoint_schema.json").unwrap();
let schema: Result<Schema, _> = serde_json::from_reader(file);
assert!(schema.is_ok())
}
#[test]
fn test_invalid_decimal() {
let data = r#"
{
"name": "a",
"type": "decimal(39, 10)",
"nullable": false,
"metadata": {}
}
"#;
assert!(serde_json::from_str::<StructField>(data).is_err());
let data = r#"
{
"name": "a",
"type": "decimal(10, 39)",
"nullable": false,
"metadata": {}
}
"#;
assert!(serde_json::from_str::<StructField>(data).is_err());
}
}