use std::{collections::HashMap, fmt, ops::Index};
use derive_builder::Builder;
use serde::{
de::{self, Error as SerdeError, IntoDeserializer, MapAccess, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use crate::error::Error;
use super::partition::Transform;
/// A schema data type: either a primitive or one of the nested
/// struct / list / map types.
///
/// The enum is `untagged` for serde, so each variant is (de)serialized in its
/// own representation with no enum wrapper. When deserializing, serde tries
/// the variants in declaration order and keeps the first one that matches.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(untagged)]
pub enum Type {
/// A primitive type; see [`PrimitiveType`].
Primitive(PrimitiveType),
/// A struct type; see [`StructType`].
Struct(StructType),
/// A list type; see [`ListType`].
List(ListType),
/// A map type; see [`MapType`].
Map(MapType),
}
/// Formats the type as a short lowercase label.
///
/// Primitive types defer to [`PrimitiveType`]'s own `Display`; nested types
/// print only their kind (`struct`, `list`, `map`), never their contents.
impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let kind = match self {
            // Primitives know how to print themselves.
            Type::Primitive(inner) => return write!(f, "{}", inner),
            Type::Struct(_) => "struct",
            Type::List(_) => "list",
            Type::Map(_) => "map",
        };
        f.write_str(kind)
    }
}
/// Primitive (non-nested) data types.
///
/// Serde support is derived with `remote = "Self"`, which generates inherent
/// `serialize` / `deserialize` functions instead of trait impls. The
/// hand-written `Serialize` / `Deserialize` impls for this type call those
/// for the plain variants and handle the parameterised `decimal(P,S)` and
/// `fixed[L]` string forms themselves.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename_all = "lowercase", remote = "Self")]
pub enum PrimitiveType {
    /// Serialized as `boolean`.
    Boolean,
    /// Serialized as `int`.
    Int,
    /// Serialized as `long`.
    Long,
    /// Serialized as `float`.
    Float,
    /// Serialized as `double`.
    Double,
    /// Serialized as `decimal(P,S)`.
    Decimal {
        /// The `P` in `decimal(P,S)`.
        precision: u32,
        /// The `S` in `decimal(P,S)`.
        scale: u32,
    },
    /// Serialized as `date`.
    Date,
    /// Serialized as `time`.
    Time,
    /// Serialized as `timestamp`.
    Timestamp,
    /// Timestamp with time zone.
    ///
    /// Serialized as `timestamptz`, the spelling used by the Apache Iceberg
    /// table spec. The variant name keeps its historical spelling so existing
    /// callers compile unchanged, and the previously emitted string
    /// `timestampz` is still accepted when deserializing.
    #[serde(rename = "timestamptz", alias = "timestampz")]
    Timestampz,
    /// Serialized as `string`.
    String,
    /// Serialized as `uuid`.
    Uuid,
    /// Serialized as `fixed[L]`, where `L` is the wrapped value.
    Fixed(u64),
    /// Serialized as `binary`.
    Binary,
}
/// Deserializes a primitive type from its string form.
///
/// Strings starting with `decimal` or `fixed` are routed to the dedicated
/// parsers; every other string is handed to the derived (remote)
/// deserializer, which matches the plain variant names.
impl<'de> Deserialize<'de> for PrimitiveType {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let repr = String::deserialize(deserializer)?;
        let looks_decimal = repr.starts_with("decimal");
        let looks_fixed = repr.starts_with("fixed");
        // The string is re-wrapped as a deserializer so each helper can
        // consume it through the normal serde interface.
        match (looks_decimal, looks_fixed) {
            (true, _) => deserialize_decimal(repr.into_deserializer()),
            (false, true) => deserialize_fixed(repr.into_deserializer()),
            // Resolves to the inherent function generated by
            // `remote = "Self"`, not to this trait impl.
            (false, false) => PrimitiveType::deserialize(repr.into_deserializer()),
        }
    }
}
impl Serialize for PrimitiveType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
PrimitiveType::Decimal { precision, scale } => {
serialize_decimal(precision, scale, serializer)
}
PrimitiveType::Fixed(l) => serialize_fixed(l, serializer),
_ => PrimitiveType::serialize(self, serializer),
}
}
}
/// Parses a `decimal(P,S)` string into [`PrimitiveType::Decimal`].
///
/// The input must have exactly one `decimal(` prefix and one `)` suffix, with
/// precision and scale separated by a comma. Whitespace around either number
/// is ignored.
///
/// # Errors
///
/// Returns a custom deserialization error if the surrounding syntax is wrong
/// (the message includes the offending input) or if either number fails to
/// parse as `u32`.
fn deserialize_decimal<'de, D>(deserializer: D) -> Result<PrimitiveType, D::Error>
where
    D: Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    let (precision, scale) = s
        // `strip_*` removes the delimiter exactly once, so nested or
        // unbalanced forms such as `decimal(decimal(9,2))` are rejected.
        .strip_prefix("decimal(")
        .and_then(|rest| rest.strip_suffix(')'))
        .and_then(|args| args.split_once(','))
        .ok_or_else(|| {
            D::Error::custom(format!("Decimal requires precision and scale: {s}"))
        })?;
    Ok(PrimitiveType::Decimal {
        precision: precision.trim().parse().map_err(D::Error::custom)?,
        scale: scale.trim().parse().map_err(D::Error::custom)?,
    })
}
/// Writes a decimal type as the string `decimal(<precision>,<scale>)`.
fn serialize_decimal<S>(precision: &u32, scale: &u32, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let text = format!("decimal({},{})", precision, scale);
    serializer.serialize_str(text.as_str())
}
/// Parses a `fixed[L]` string into [`PrimitiveType::Fixed`].
///
/// The input must have exactly one `fixed[` prefix and one `]` suffix.
/// Whitespace around the length is ignored.
///
/// # Errors
///
/// Returns a custom deserialization error if the brackets are missing (the
/// message includes the offending input) or if the length fails to parse as
/// `u64`.
fn deserialize_fixed<'de, D>(deserializer: D) -> Result<PrimitiveType, D::Error>
where
    D: Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    let length = s
        // `strip_*` removes the delimiter exactly once, so malformed input
        // such as `fixed[fixed[8]]` is rejected instead of silently accepted.
        .strip_prefix("fixed[")
        .and_then(|rest| rest.strip_suffix(']'))
        .ok_or_else(|| D::Error::custom(format!("Fixed requires a length: {s}")))?;
    length
        .trim()
        .parse()
        .map(PrimitiveType::Fixed)
        .map_err(D::Error::custom)
}
/// Writes a fixed-length type as the string `fixed[<value>]`.
fn serialize_fixed<S>(value: &u64, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let text = format!("fixed[{}]", value);
    serializer.serialize_str(text.as_str())
}
/// Formats the primitive type as a bare lowercase label.
///
/// The label omits parameters: every `Decimal` prints as `decimal` and every
/// `Fixed` prints as `fixed`, so the output is not the serialized form.
impl fmt::Display for PrimitiveType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            PrimitiveType::Boolean => "boolean",
            PrimitiveType::Int => "int",
            PrimitiveType::Long => "long",
            PrimitiveType::Float => "float",
            PrimitiveType::Double => "double",
            PrimitiveType::Decimal { .. } => "decimal",
            PrimitiveType::Date => "date",
            PrimitiveType::Time => "time",
            PrimitiveType::Timestamp => "timestamp",
            // NOTE(review): the Iceberg spec spells this type `timestamptz`;
            // confirm whether this label is meant to match it.
            PrimitiveType::Timestampz => "timestampz",
            PrimitiveType::String => "string",
            PrimitiveType::Uuid => "uuid",
            PrimitiveType::Fixed(_) => "fixed",
            PrimitiveType::Binary => "binary",
        };
        f.write_str(label)
    }
}
/// A struct type: an ordered list of fields plus a lookup table from field id
/// to position.
///
/// Serialization is derived: the output carries a `"type": "struct"` entry
/// followed by `fields`, and `lookup` is never written. Deserialization is
/// hand-written in the `Deserialize` impl for this type so that `lookup` can
/// be rebuilt from the parsed fields.
#[derive(Debug, Serialize, PartialEq, Eq, Clone, Builder)]
#[serde(rename = "struct", tag = "type")]
pub struct StructType {
/// The fields, in declaration order. The generated builder can add them one
/// at a time through `with_struct_field`.
#[builder(setter(each(name = "with_struct_field")))]
pub fields: Vec<StructField>,
/// Maps each field's `id` to its position in `fields`.
///
/// Skipped during serialization. When the builder is used without setting it
/// explicitly, the default expression below derives it from the builder's
/// `fields`.
/// NOTE(review): `fields` is public, so mutating it after construction leaves
/// this table stale — confirm whether any caller does that.
#[serde(skip_serializing)]
#[builder(
default = "self.fields.as_ref().unwrap().iter().enumerate().map(|(idx, field)| (field.id, idx)).collect()"
)]
lookup: HashMap<i32, usize>,
}
impl<'de> Deserialize<'de> for StructType {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field {
Type,
Fields,
}
struct StructTypeVisitor;
impl<'de> Visitor<'de> for StructTypeVisitor {
type Value = StructType;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct")
}
fn visit_map<V>(self, mut map: V) -> std::result::Result<StructType, V::Error>
where
V: MapAccess<'de>,
{
let mut fields = None;
while let Some(key) = map.next_key()? {
match key {
Field::Type => (),
Field::Fields => {
if fields.is_some() {
return Err(serde::de::Error::duplicate_field("fields"));
}
fields = Some(map.next_value()?);
}
}
}
let fields: Vec<StructField> =
fields.ok_or_else(|| de::Error::missing_field("fields"))?;
Ok(StructType::new(fields))
}
}
const FIELDS: &[&str] = &["type", "fields"];
deserializer.deserialize_struct("struct", FIELDS, StructTypeVisitor)
}
}
impl StructType {
    /// Creates a struct type from `fields`, building the id-to-position
    /// lookup table as it goes.
    ///
    /// If several fields share an id, the lookup keeps the last one.
    pub fn new(fields: Vec<StructField>) -> Self {
        let lookup = fields
            .iter()
            .enumerate()
            .map(|(idx, field)| (field.id, idx))
            .collect();
        StructType { fields, lookup }
    }

    /// Returns the field whose **id** equals `index`, or `None` if there is
    /// no such field.
    ///
    /// Despite the parameter name this is a lookup by field id, not by
    /// position; use indexing (`struct_type[i]`) for positional access.
    /// Values that do not fit in an `i32` cannot be a field id and yield
    /// `None` rather than wrapping around to some other id.
    pub fn get(&self, index: usize) -> Option<&StructField> {
        // Checked conversion: a plain `as i32` cast would wrap (for example
        // `usize::MAX` would become `-1`) and could match an unrelated field.
        let id = <i32 as std::convert::TryFrom<usize>>::try_from(index).ok()?;
        let position = *self.lookup.get(&id)?;
        // `fields` is public and may have been changed since the lookup was
        // built, so use a checked access instead of panicking on a stale entry.
        self.fields.get(position)
    }

    /// Returns the first field named `name`, or `None` if no field has that
    /// name. This is a linear scan over the fields.
    pub fn get_name(&self, name: &str) -> Option<&StructField> {
        self.fields.iter().find(|field| field.name == name)
    }
}
/// Positional access to the fields: `struct_type[i]` is the `i`-th entry of
/// `fields`.
///
/// # Panics
///
/// Panics if `index` is out of bounds.
impl Index<usize> for StructType {
    type Output = StructField;

    fn index(&self, index: usize) -> &Self::Output {
        let fields: &Vec<StructField> = &self.fields;
        &fields[index]
    }
}
/// A single named field of a [`StructType`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct StructField {
/// Field id; `StructType` keys its lookup table on this value.
pub id: i32,
/// Field name; `StructType::get_name` matches against it.
pub name: String,
/// Whether the field is required.
/// NOTE(review): presumably means "non-nullable" — confirm.
pub required: bool,
/// The field's data type; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub field_type: Type,
/// Optional documentation string; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub doc: Option<String>,
}
/// A list type.
///
/// Serialized with a `"type": "list"` entry; field names are kebab-case
/// (`element-id`, `element-required`, `element`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename = "list", rename_all = "kebab-case", tag = "type")]
pub struct ListType {
/// Id assigned to the element.
pub element_id: i32,
/// Whether elements are required.
/// NOTE(review): presumably means "non-nullable" — confirm.
pub element_required: bool,
/// The element type; boxed because `Type` is recursive.
pub element: Box<Type>,
}
/// A map type.
///
/// Serialized with a `"type": "map"` entry; field names are kebab-case
/// (`key-id`, `key`, `value-id`, `value-required`, `value`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename = "map", rename_all = "kebab-case", tag = "type")]
pub struct MapType {
/// Id assigned to the key.
pub key_id: i32,
/// The key type; boxed because `Type` is recursive.
pub key: Box<Type>,
/// Id assigned to the value.
pub value_id: i32,
/// Whether values are required.
/// NOTE(review): presumably means "non-nullable" — confirm.
pub value_required: bool,
/// The value type; boxed because `Type` is recursive.
pub value: Box<Type>,
}
impl Type {
    /// Returns the type that results from applying `transform` to a value of
    /// this type.
    ///
    /// * `Identity` and `Truncate` keep the source type (a clone of `self`).
    /// * `Bucket`, `Year`, `Month`, `Day` and `Hour` produce `int`.
    ///
    /// The method name is spelled `tranform`; it is kept as-is so existing
    /// callers continue to compile.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotSupported`] for `Transform::Void`.
    pub fn tranform(&self, transform: &Transform) -> Result<Type, Error> {
        match transform {
            // Transforms that preserve the source type.
            Transform::Identity | Transform::Truncate(_) => Ok(self.clone()),
            // Transforms whose result is always an `int`.
            Transform::Bucket(_)
            | Transform::Year
            | Transform::Month
            | Transform::Day
            | Transform::Hour => Ok(Type::Primitive(PrimitiveType::Int)),
            Transform::Void => Err(Error::NotSupported("void transform".to_string())),
        }
    }
}
// Serde round-trip and parsing tests for the schema types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Round-trip helper: `json` must deserialize to `expected_type`, and
// serializing `expected_type` must produce JSON equal to `json`. The two sides
// are compared as parsed `serde_json::Value`s rather than as raw text.
fn check_type_serde(json: &str, expected_type: Type) {
let desered_type: Type = serde_json::from_str(json).unwrap();
assert_eq!(desered_type, expected_type);
let sered_json = serde_json::to_string(&expected_type).unwrap();
let parsed_json_value = serde_json::from_str::<serde_json::Value>(&sered_json).unwrap();
let raw_json_value = serde_json::from_str::<serde_json::Value>(json).unwrap();
assert_eq!(parsed_json_value, raw_json_value);
}
// A struct with one `decimal(9,2)` field round-trips through `Type`.
#[test]
fn decimal() {
let record = r#"
{
"type": "struct",
"fields": [
{
"id": 1,
"name": "id",
"required": true,
"type": "decimal(9,2)"
}
]
}
"#;
check_type_serde(
record,
Type::Struct(StructType::new(vec![StructField {
id: 1,
name: "id".to_string(),
field_type: Type::Primitive(PrimitiveType::Decimal {
precision: 9,
scale: 2,
}),
required: true,
doc: None,
}])),
)
}
// A struct with one `fixed[8]` field round-trips through `Type`.
#[test]
fn fixed() {
let record = r#"
{
"type": "struct",
"fields": [
{
"id": 1,
"name": "id",
"required": true,
"type": "fixed[8]"
}
]
}
"#;
check_type_serde(
record,
Type::Struct(StructType::new(vec![StructField {
id: 1,
name: "id".to_string(),
field_type: Type::Primitive(PrimitiveType::Fixed(8)),
required: true,
doc: None,
}])),
)
}
// A struct with two plain primitive fields (one required, one optional)
// round-trips through `Type`.
#[test]
fn struct_type() {
let record = r#"
{
"type": "struct",
"fields": [
{
"id": 1,
"name": "id",
"required": true,
"type": "uuid"
}, {
"id": 2,
"name": "data",
"required": false,
"type": "int"
}
]
}
"#;
check_type_serde(
record,
Type::Struct(StructType::new(vec![
StructField {
id: 1,
name: "id".to_string(),
field_type: Type::Primitive(PrimitiveType::Uuid),
required: true,
doc: None,
},
StructField {
id: 2,
name: "data".to_string(),
field_type: Type::Primitive(PrimitiveType::Int),
required: false,
doc: None,
},
])),
)
}
// Deserializes a `ListType` directly (not through `Type`) and checks only the
// element type; serialization is not exercised here.
#[test]
fn list() {
let record = r#"
{
"type": "list",
"element-id": 3,
"element-required": true,
"element": "string"
}
"#;
let result: ListType = serde_json::from_str(record).unwrap();
assert_eq!(Type::Primitive(PrimitiveType::String), *result.element);
}
// Deserializes a `MapType` directly (not through `Type`) and checks only the
// key and value types; serialization is not exercised here.
#[test]
fn map() {
let record = r#"
{
"type": "map",
"key-id": 4,
"key": "string",
"value-id": 5,
"value-required": false,
"value": "double"
}
"#;
let result: MapType = serde_json::from_str(record).unwrap();
assert_eq!(Type::Primitive(PrimitiveType::String), *result.key);
assert_eq!(Type::Primitive(PrimitiveType::Double), *result.value);
}
}