1use std::{collections::HashMap, fmt, ops::Index, slice::Iter};
19
20use derive_builder::Builder;
21
22use itertools::Itertools;
23use serde::{
24 de::{self, Error as SerdeError, IntoDeserializer, MapAccess, Visitor},
25 Deserialize, Deserializer, Serialize, Serializer,
26};
27
28use crate::error::Error;
29
30use super::partition::Transform;
31
/// A data type that can appear in a schema: a primitive, a struct, a list or a map.
///
/// The enum uses serde's `untagged` representation, so the variant itself is not
/// named in the serialized form; only the payload of the variant is written.
/// When deserializing, serde tries the variants in declaration order and keeps the
/// first one that succeeds, so the order of the variants below is significant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(untagged)]
pub enum Type {
    /// A primitive type, see [`PrimitiveType`].
    Primitive(PrimitiveType),
    /// A struct type, see [`StructType`].
    Struct(StructType),
    /// A list type, see [`ListType`].
    List(ListType),
    /// A map type, see [`MapType`].
    Map(MapType),
}
45
46impl fmt::Display for Type {
47 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48 match self {
49 Type::Primitive(primitive) => write!(f, "{primitive}"),
50 Type::Struct(_) => write!(f, "struct"),
51 Type::List(_) => write!(f, "list"),
52 Type::Map(_) => write!(f, "map"),
53 }
54 }
55}
56
/// The primitive (non-nested) data types.
///
/// Unit variants are (de)serialized as their lowercase name (`rename_all = "lowercase"`).
/// `remote = "Self"` makes the derived `serialize`/`deserialize` functions inherent
/// functions of this type rather than trait implementations; the hand-written
/// `Serialize`/`Deserialize` impls for this type delegate to them for every variant
/// except `Decimal` and `Fixed`, which use the string forms `decimal(P,S)` and `fixed[L]`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename_all = "lowercase", remote = "Self")]
pub enum PrimitiveType {
    /// Serialized as `"boolean"`.
    Boolean,
    /// Serialized as `"int"`.
    Int,
    /// Serialized as `"long"`.
    Long,
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"double"`.
    Double,
    /// Decimal type, serialized as `"decimal(precision,scale)"`.
    Decimal {
        /// The precision of the decimal.
        precision: u32,
        /// The scale of the decimal.
        scale: u32,
    },
    /// Serialized as `"date"`.
    Date,
    /// Serialized as `"time"`.
    Time,
    /// Serialized as `"timestamp"`.
    Timestamp,
    /// Serialized as `"timestamptz"`.
    Timestamptz,
    /// Serialized as `"string"`.
    String,
    /// Serialized as `"uuid"`.
    Uuid,
    /// Fixed type, serialized as `"fixed[L]"` where `L` is the wrapped value.
    Fixed(u64),
    /// Serialized as `"binary"`.
    Binary,
}
95
96impl<'de> Deserialize<'de> for PrimitiveType {
97 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
98 where
99 D: Deserializer<'de>,
100 {
101 let s = String::deserialize(deserializer)?;
102 if s.starts_with("decimal") {
103 deserialize_decimal(s.into_deserializer())
104 } else if s.starts_with("fixed") {
105 deserialize_fixed(s.into_deserializer())
106 } else {
107 PrimitiveType::deserialize(s.into_deserializer())
108 }
109 }
110}
111
112impl Serialize for PrimitiveType {
113 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
114 where
115 S: serde::Serializer,
116 {
117 match self {
118 PrimitiveType::Decimal { precision, scale } => {
119 serialize_decimal(precision, scale, serializer)
120 }
121 PrimitiveType::Fixed(l) => serialize_fixed(l, serializer),
122 _ => PrimitiveType::serialize(self, serializer),
123 }
124 }
125}
126
127fn deserialize_decimal<'de, D>(deserializer: D) -> Result<PrimitiveType, D::Error>
128where
129 D: Deserializer<'de>,
130{
131 let s = String::deserialize(deserializer)?;
132 let (precision, scale) = s
133 .trim_start_matches(r"decimal(")
134 .trim_end_matches(')')
135 .split_once(',')
136 .ok_or_else(|| D::Error::custom("Decimal requires precision and scale: {s}"))?;
137
138 Ok(PrimitiveType::Decimal {
139 precision: precision.parse().map_err(D::Error::custom)?,
140 scale: scale.trim().parse().map_err(D::Error::custom)?,
141 })
142}
143
144fn serialize_decimal<S>(precision: &u32, scale: &u32, serializer: S) -> Result<S::Ok, S::Error>
145where
146 S: Serializer,
147{
148 serializer.serialize_str(&format!("decimal({precision},{scale})"))
149}
150
151fn deserialize_fixed<'de, D>(deserializer: D) -> Result<PrimitiveType, D::Error>
152where
153 D: Deserializer<'de>,
154{
155 let fixed = String::deserialize(deserializer)?
156 .trim_start_matches(r"fixed[")
157 .trim_end_matches(']')
158 .to_owned();
159
160 fixed
161 .parse()
162 .map(PrimitiveType::Fixed)
163 .map_err(D::Error::custom)
164}
165
166fn serialize_fixed<S>(value: &u64, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169{
170 serializer.serialize_str(&format!("fixed[{value}]"))
171}
172
173impl fmt::Display for PrimitiveType {
174 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
175 match self {
176 PrimitiveType::Boolean => write!(f, "boolean"),
177 PrimitiveType::Int => write!(f, "int"),
178 PrimitiveType::Long => write!(f, "long"),
179 PrimitiveType::Float => write!(f, "float"),
180 PrimitiveType::Double => write!(f, "double"),
181 PrimitiveType::Decimal {
182 precision: _,
183 scale: _,
184 } => write!(f, "decimal"),
185 PrimitiveType::Date => write!(f, "date"),
186 PrimitiveType::Time => write!(f, "time"),
187 PrimitiveType::Timestamp => write!(f, "timestamp"),
188 PrimitiveType::Timestamptz => write!(f, "timestamptz"),
189 PrimitiveType::String => write!(f, "string"),
190 PrimitiveType::Uuid => write!(f, "uuid"),
191 PrimitiveType::Fixed(_) => write!(f, "fixed"),
192 PrimitiveType::Binary => write!(f, "binary"),
193 }
194 }
195}
196
/// A struct type: an ordered collection of [`StructField`]s.
///
/// Serialized as a map that carries a `"type": "struct"` tag next to the
/// `fields` entry. `Deserialize` is implemented by hand for this type.
/// A `StructTypeBuilder` is generated by `derive_builder`; its `build`
/// function reports failures as the crate's `Error` type.
#[derive(Debug, Serialize, PartialEq, Eq, Clone, Builder)]
#[serde(rename = "struct", tag = "type")]
#[builder(build_fn(error = "Error"))]
pub struct StructType {
    /// The fields of the struct, in declaration order. The builder can add
    /// them one at a time through `with_struct_field`.
    #[builder(setter(each(name = "with_struct_field")))]
    fields: Vec<StructField>,
    /// Maps a field id to the position of that field inside `fields`.
    /// Never serialized; when not set explicitly on the builder it is
    /// derived from the builder's `fields`.
    #[serde(skip_serializing)]
    #[builder(
        default = "self.fields.as_ref().unwrap().iter().enumerate().map(|(idx, field)| (field.id, idx)).collect()"
    )]
    lookup: HashMap<i32, usize>,
}
212
213impl<'de> Deserialize<'de> for StructType {
214 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
215 where
216 D: Deserializer<'de>,
217 {
218 #[derive(Deserialize)]
219 #[serde(field_identifier, rename_all = "lowercase")]
220 enum Field {
221 Type,
222 Fields,
223 }
224
225 struct StructTypeVisitor;
226
227 impl<'de> Visitor<'de> for StructTypeVisitor {
228 type Value = StructType;
229
230 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
231 formatter.write_str("struct")
232 }
233
234 fn visit_map<V>(self, mut map: V) -> std::result::Result<StructType, V::Error>
235 where
236 V: MapAccess<'de>,
237 {
238 let mut fields = None;
239 while let Some(key) = map.next_key()? {
240 match key {
241 Field::Type => (),
242 Field::Fields => {
243 if fields.is_some() {
244 return Err(serde::de::Error::duplicate_field("fields"));
245 }
246 fields = Some(map.next_value()?);
247 }
248 }
249 }
250 let fields: Vec<StructField> =
251 fields.ok_or_else(|| de::Error::missing_field("fields"))?;
252
253 Ok(StructType::new(fields))
254 }
255 }
256
257 const FIELDS: &[&str] = &["type", "fields"];
258 deserializer.deserialize_struct("struct", FIELDS, StructTypeVisitor)
259 }
260}
261
262impl StructType {
263 pub fn new(fields: Vec<StructField>) -> Self {
271 let lookup = fields
272 .iter()
273 .enumerate()
274 .map(|(idx, field)| (field.id, idx))
275 .collect();
276 StructType { fields, lookup }
277 }
278
279 pub fn builder() -> StructTypeBuilder {
284 StructTypeBuilder::default()
285 }
286
287 #[inline]
296 pub fn get(&self, index: usize) -> Option<&StructField> {
297 self.lookup
298 .get(&(index as i32))
299 .map(|idx| &self.fields[*idx])
300 }
301
302 pub fn get_name(&self, name: &str) -> Option<&StructField> {
311 let res = self.fields.iter().find(|field| field.name == name);
312 if res.is_some() {
313 return res;
314 }
315 let parts: Vec<&str> = name.split('.').collect();
316 let mut current_struct = self;
317 let mut current_field = None;
318
319 for (i, part) in parts.iter().enumerate() {
320 current_field = current_struct
321 .fields
322 .iter()
323 .find(|field| field.name == *part);
324
325 if i == parts.len() - 1 || current_field.is_some() {
326 return current_field;
327 }
328
329 if let Some(field) = current_field {
330 if let Type::Struct(struct_type) = &field.field_type {
331 current_struct = struct_type;
332 } else {
333 return None;
334 }
335 }
336 }
337
338 current_field
339 }
340
341 pub fn len(&self) -> usize {
346 self.fields.len()
347 }
348
349 pub fn is_empty(&self) -> bool {
355 self.fields.is_empty()
356 }
357
358 pub fn iter(&self) -> Iter<'_, StructField> {
363 self.fields.iter()
364 }
365
366 pub fn field_ids(&self) -> impl Iterator<Item = i32> {
371 self.lookup.keys().map(ToOwned::to_owned).sorted()
372 }
373
374 pub fn primitive_field_ids(&self) -> impl Iterator<Item = i32> {
382 self.lookup
383 .iter()
384 .filter(|(_, x)| matches!(self.fields[**x].field_type, Type::Primitive(_)))
385 .map(|x| x.0.to_owned())
386 .sorted()
387 }
388}
389
390impl Index<usize> for StructType {
391 type Output = StructField;
392
393 fn index(&self, index: usize) -> &Self::Output {
394 &self.fields[index]
395 }
396}
397
/// A single named field of a [`StructType`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct StructField {
    /// The id of the field; used as the key of the struct's id lookup.
    pub id: i32,
    /// The name of the field.
    pub name: String,
    /// Whether the field is required.
    pub required: bool,
    /// The type of the field; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub field_type: Type,
    /// Optional documentation string; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub doc: Option<String>,
}
416
417impl StructField {
418 pub fn new(id: i32, name: &str, required: bool, field_type: Type, doc: Option<String>) -> Self {
427 Self {
428 id,
429 name: name.to_owned(),
430 required,
431 field_type,
432 doc,
433 }
434 }
435}
436
/// A list type.
///
/// Serialized as a map tagged with `"type": "list"`; the field names are
/// written in kebab-case (`element-id`, `element-required`, `element`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename = "list", rename_all = "kebab-case", tag = "type")]
pub struct ListType {
    /// The id of the list element.
    pub element_id: i32,

    /// Whether the list element is required.
    pub element_required: bool,

    /// The type of the list element.
    pub element: Box<Type>,
}
451
/// A map type.
///
/// Serialized as a map tagged with `"type": "map"`; the field names are
/// written in kebab-case (`key-id`, `key`, `value-id`, `value-required`, `value`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename = "map", rename_all = "kebab-case", tag = "type")]
pub struct MapType {
    /// The id of the map key.
    pub key_id: i32,
    /// The type of the map key.
    pub key: Box<Type>,
    /// The id of the map value.
    pub value_id: i32,
    /// Whether the map value is required.
    pub value_required: bool,
    /// The type of the map value.
    pub value: Box<Type>,
}
470
471impl Type {
472 pub fn tranform(&self, transform: &Transform) -> Result<Type, Error> {
474 match transform {
475 Transform::Identity => Ok(self.clone()),
476 Transform::Bucket(_) => Ok(Type::Primitive(PrimitiveType::Int)),
477 Transform::Truncate(_) => Ok(self.clone()),
478 Transform::Year => Ok(Type::Primitive(PrimitiveType::Int)),
479 Transform::Month => Ok(Type::Primitive(PrimitiveType::Int)),
480 Transform::Day => Ok(Type::Primitive(PrimitiveType::Int)),
481 Transform::Hour => Ok(Type::Primitive(PrimitiveType::Int)),
482 Transform::Void => Err(Error::NotSupported("void transform".to_string())),
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks both directions of the JSON mapping: `json` must deserialize to
    /// `expected_type`, and `expected_type` must serialize to JSON that is
    /// structurally equal to `json`.
    fn check_type_serde(json: &str, expected_type: Type) {
        // JSON -> Type
        let parsed: Type = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, expected_type);

        // Type -> JSON, compared as values so that formatting does not matter.
        let rendered = serde_json::to_string(&expected_type).unwrap();
        let actual: serde_json::Value = serde_json::from_str(&rendered).unwrap();
        let wanted: serde_json::Value = serde_json::from_str(json).unwrap();

        assert_eq!(actual, wanted);
    }

    /// A struct with a single `decimal(9,2)` field round-trips.
    #[test]
    fn decimal() {
        let record = r#"
        {
            "type": "struct",
            "fields": [
                {
                    "id": 1,
                    "name": "id",
                    "required": true,
                    "type": "decimal(9,2)"
                }
            ]
        }
        "#;

        let decimal_type = Type::Primitive(PrimitiveType::Decimal {
            precision: 9,
            scale: 2,
        });
        let expected = StructType::new(vec![StructField::new(1, "id", true, decimal_type, None)]);

        check_type_serde(record, Type::Struct(expected))
    }

    /// A struct with a single `fixed[8]` field round-trips.
    #[test]
    fn fixed() {
        let record = r#"
        {
            "type": "struct",
            "fields": [
                {
                    "id": 1,
                    "name": "id",
                    "required": true,
                    "type": "fixed[8]"
                }
            ]
        }
        "#;

        let fixed_type = Type::Primitive(PrimitiveType::Fixed(8));
        let expected = StructType::new(vec![StructField::new(1, "id", true, fixed_type, None)]);

        check_type_serde(record, Type::Struct(expected))
    }

    /// A struct with a required and an optional primitive field round-trips.
    #[test]
    fn struct_type() {
        let record = r#"
        {
            "type": "struct",
            "fields": [
                {
                    "id": 1,
                    "name": "id",
                    "required": true,
                    "type": "uuid"
                }, {
                    "id": 2,
                    "name": "data",
                    "required": false,
                    "type": "int"
                }
            ]
        }
        "#;

        let expected = StructType::new(vec![
            StructField::new(1, "id", true, Type::Primitive(PrimitiveType::Uuid), None),
            StructField::new(2, "data", false, Type::Primitive(PrimitiveType::Int), None),
        ]);

        check_type_serde(record, Type::Struct(expected))
    }

    /// A list type deserializes and exposes its element type.
    #[test]
    fn list() {
        let record = r#"
        {
            "type": "list",
            "element-id": 3,
            "element-required": true,
            "element": "string"
        }
        "#;

        let parsed: ListType = serde_json::from_str(record).unwrap();
        assert_eq!(Type::Primitive(PrimitiveType::String), *parsed.element);
    }

    /// A map type deserializes and exposes its key and value types.
    #[test]
    fn map() {
        let record = r#"
        {
            "type": "map",
            "key-id": 4,
            "key": "string",
            "value-id": 5,
            "value-required": false,
            "value": "double"
        }
        "#;

        let parsed: MapType = serde_json::from_str(record).unwrap();
        assert_eq!(Type::Primitive(PrimitiveType::String), *parsed.key);
        assert_eq!(Type::Primitive(PrimitiveType::Double), *parsed.value);
    }
}