use std::borrow::Cow;
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};
use crate::transforms::{map_owned_children_or_else, map_owned_or_else, map_owned_pair_or_else};
/// A visitor over a [`DataType`] tree that can optionally produce a rewritten copy of it.
///
/// Every node kind has a `transform_*` hook. The provided defaults either keep the node as-is
/// (primitives) or walk into its children through the matching `recurse_into_*` helper, so an
/// implementor only needs to override the hooks for the node kinds it cares about. An override
/// that still wants to visit the node's children can call the `recurse_into_*` helper itself.
///
/// Hooks return `Option<Cow<'a, _>>`. The default primitive hook returns `Cow::Borrowed` to mean
/// "unchanged", and a `Cow::Owned` value carries a replacement. How `None` and `Owned` child
/// results propagate to the parent is decided by the `map_owned_*_or_else` helpers in
/// `crate::transforms`. Presumably `None` drops the node and any `Owned` child forces the parent
/// to be rebuilt — confirm against those helpers.
pub trait SchemaTransform<'a> {
    /// Hook for primitive leaves. The default keeps the primitive unchanged.
    fn transform_primitive(
        &mut self,
        ptype: &'a PrimitiveType,
    ) -> Option<Cow<'a, PrimitiveType>> {
        let unchanged = Cow::Borrowed(ptype);
        Some(unchanged)
    }

    /// Hook for struct types. The default visits every field via
    /// [`Self::recurse_into_struct`].
    fn transform_struct(&mut self, stype: &'a StructType) -> Option<Cow<'a, StructType>> {
        self.recurse_into_struct(stype)
    }

    /// Hook for a single struct field. The default transforms the field's data type via
    /// [`Self::recurse_into_struct_field`].
    fn transform_struct_field(&mut self, field: &'a StructField) -> Option<Cow<'a, StructField>> {
        self.recurse_into_struct_field(field)
    }

    /// Hook for array types. The default visits the element type via
    /// [`Self::recurse_into_array`].
    fn transform_array(&mut self, atype: &'a ArrayType) -> Option<Cow<'a, ArrayType>> {
        self.recurse_into_array(atype)
    }

    /// Hook for an array's element type. The default hands it to [`Self::transform`].
    fn transform_array_element(&mut self, etype: &'a DataType) -> Option<Cow<'a, DataType>> {
        self.transform(etype)
    }

    /// Hook for map types. The default visits the key and then the value via
    /// [`Self::recurse_into_map`].
    fn transform_map(&mut self, mtype: &'a MapType) -> Option<Cow<'a, MapType>> {
        self.recurse_into_map(mtype)
    }

    /// Hook for a map's key type. The default hands it to [`Self::transform`].
    fn transform_map_key(&mut self, etype: &'a DataType) -> Option<Cow<'a, DataType>> {
        self.transform(etype)
    }

    /// Hook for a map's value type. The default hands it to [`Self::transform`].
    fn transform_map_value(&mut self, etype: &'a DataType) -> Option<Cow<'a, DataType>> {
        self.transform(etype)
    }

    /// Hook for variant types, given the variant's underlying struct. The default visits that
    /// struct's fields via [`Self::recurse_into_struct`]. It does *not* go through
    /// [`Self::transform_struct`], so struct-level overrides are not applied here.
    fn transform_variant(&mut self, stype: &'a StructType) -> Option<Cow<'a, StructType>> {
        self.recurse_into_struct(stype)
    }

    /// Entry point for any node. Dispatches `data_type` to the hook for its variant. When a
    /// replacement must be built, the hook's output is wrapped back into a `DataType`.
    fn transform(&mut self, data_type: &'a DataType) -> Option<Cow<'a, DataType>> {
        match data_type {
            DataType::Primitive(ptype) => {
                map_owned_or_else(data_type, self.transform_primitive(ptype), DataType::from)
            }
            DataType::Array(atype) => {
                map_owned_or_else(data_type, self.transform_array(atype), DataType::from)
            }
            DataType::Struct(stype) => {
                map_owned_or_else(data_type, self.transform_struct(stype), DataType::from)
            }
            DataType::Map(mtype) => {
                map_owned_or_else(data_type, self.transform_map(mtype), DataType::from)
            }
            // `DataType::from` on a `StructType` presumably yields `DataType::Struct`. A rebuilt
            // variant is therefore boxed and wrapped by hand so that it stays a `Variant`.
            DataType::Variant(stype) => map_owned_or_else(
                data_type,
                self.transform_variant(stype),
                |rebuilt| DataType::Variant(Box::new(rebuilt)),
            ),
        }
    }

    /// Transforms the field's data type. Whenever a new field has to be built, it copies the
    /// name, nullability, and metadata from `field` and swaps in the transformed type.
    fn recurse_into_struct_field(
        &mut self,
        field: &'a StructField,
    ) -> Option<Cow<'a, StructField>> {
        let transformed_type = self.transform(&field.data_type);
        map_owned_or_else(field, transformed_type, |data_type| StructField {
            data_type,
            name: field.name.clone(),
            nullable: field.nullable,
            metadata: field.metadata.clone(),
        })
    }

    /// Runs [`Self::transform_struct_field`] on each field yielded by `stype.fields()`. A
    /// replacement struct, when needed, is assembled with `StructType::new_unchecked`.
    fn recurse_into_struct(&mut self, stype: &'a StructType) -> Option<Cow<'a, StructType>> {
        map_owned_children_or_else(
            stype,
            stype
                .fields()
                .map(|child| self.transform_struct_field(child)),
            StructType::new_unchecked,
        )
    }

    /// Transforms the array's element type via [`Self::transform_array_element`]. A
    /// replacement array keeps the original `type_name` and `contains_null`.
    fn recurse_into_array(&mut self, atype: &'a ArrayType) -> Option<Cow<'a, ArrayType>> {
        let transformed_element = self.transform_array_element(&atype.element_type);
        map_owned_or_else(atype, transformed_element, |new_element| ArrayType {
            element_type: new_element,
            type_name: atype.type_name.clone(),
            contains_null: atype.contains_null,
        })
    }

    /// Transforms the key type and then the value type, in that order, which stateful
    /// transforms may observe. A replacement map keeps the original `type_name` and
    /// `value_contains_null`.
    fn recurse_into_map(&mut self, mtype: &'a MapType) -> Option<Cow<'a, MapType>> {
        map_owned_pair_or_else(
            mtype,
            self.transform_map_key(&mtype.key_type),
            self.transform_map_value(&mtype.value_type),
            |(key_type, value_type)| MapType {
                key_type,
                value_type,
                type_name: mtype.type_name.clone(),
                value_contains_null: mtype.value_contains_null,
            },
        )
    }
}
/// Schema visitor that measures how deeply a [`DataType`] is nested, and stops descending
/// once a configurable depth limit has been exceeded.
#[derive(Debug)]
pub struct SchemaDepthChecker {
    /// Deepest level the walk is allowed to descend through before it stops.
    depth_limit: usize,
    /// Deepest level reached so far during the walk.
    max_depth_seen: usize,
    /// Level of the node currently being visited.
    current_depth: usize,
    /// Number of depth-limited hook invocations so far. Only reported by
    /// `check_with_call_count`; it shows how early the walk terminated.
    call_count: usize,
}
impl SchemaDepthChecker {
pub fn check(data_type: &DataType, depth_limit: usize) -> usize {
Self::check_with_call_count(data_type, depth_limit).0
}
fn check_with_call_count(data_type: &DataType, depth_limit: usize) -> (usize, usize) {
let mut checker = Self {
depth_limit,
max_depth_seen: 0,
current_depth: 0,
call_count: 0,
};
let _ = checker.transform(data_type);
(checker.max_depth_seen, checker.call_count)
}
fn depth_limited<'a, T: Clone + std::fmt::Debug>(
&mut self,
recurse: impl FnOnce(&mut Self, &'a T) -> Option<Cow<'a, T>>,
arg: &'a T,
) -> Option<Cow<'a, T>> {
self.call_count += 1;
if self.max_depth_seen < self.current_depth {
self.max_depth_seen = self.current_depth;
if self.depth_limit < self.current_depth {
tracing::warn!("Max schema depth {} exceeded by {arg:?}", self.depth_limit);
}
}
if self.max_depth_seen <= self.depth_limit {
self.current_depth += 1;
let _ = recurse(self, arg);
self.current_depth -= 1;
}
None
}
}
/// Routes each node kind that can contain nested types through `depth_limited`, so every
/// struct, struct field, array, and map adds one level of depth. Primitive leaves and the
/// array-element / map-key / map-value hooks keep their defaults and add no level of their own.
///
/// NOTE(review): `transform_variant` is not overridden. A variant's inner struct therefore does
/// not count as a level by itself, although its fields still do via `transform_struct_field`.
/// Confirm this is intended.
impl<'a> SchemaTransform<'a> for SchemaDepthChecker {
    fn transform_struct(&mut self, stype: &'a StructType) -> Option<Cow<'a, StructType>> {
        self.depth_limited(|this, node| this.recurse_into_struct(node), stype)
    }

    fn transform_struct_field(&mut self, field: &'a StructField) -> Option<Cow<'a, StructField>> {
        self.depth_limited(|this, node| this.recurse_into_struct_field(node), field)
    }

    fn transform_array(&mut self, atype: &'a ArrayType) -> Option<Cow<'a, ArrayType>> {
        self.depth_limited(|this, node| this.recurse_into_array(node), atype)
    }

    fn transform_map(&mut self, mtype: &'a MapType) -> Option<Cow<'a, MapType>> {
        self.depth_limited(|this, node| this.recurse_into_map(node), mtype)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{DataType, StructField};

    /// Builds the fixture schema for `test_depth_checker`.
    ///
    /// Its deepest leaves, `b.q.s.u` and `b.q.s.v`, sit at depth 7, counting every struct,
    /// struct field, array, and map on the path. Field order matters: the checker stops early,
    /// so call counts depend on traversal order.
    fn nested_test_schema() -> DataType {
        let long_array = || ArrayType::new(DataType::LONG, true);
        let long_to_string_map = || MapType::new(DataType::LONG, DataType::STRING, true);

        // a: array<struct<w: long, x: array<long>, y: map<long, string>, z: struct<n, m>>>
        let a_z = DataType::try_struct_type([
            StructField::nullable("n", DataType::LONG),
            StructField::nullable("m", DataType::STRING),
        ])
        .unwrap();
        let a_element = DataType::try_struct_type([
            StructField::nullable("w", DataType::LONG),
            StructField::nullable("x", long_array()),
            StructField::nullable("y", long_to_string_map()),
            StructField::nullable("z", a_z),
        ])
        .unwrap();

        // b: struct<o: array<long>, p: map<long, string>, q: struct<s: struct<u, v>, t>, r: long>
        let b_q_s = DataType::try_struct_type([
            StructField::nullable("u", DataType::LONG),
            StructField::nullable("v", DataType::LONG),
        ])
        .unwrap();
        let b_q = DataType::try_struct_type([
            StructField::nullable("s", b_q_s),
            StructField::nullable("t", DataType::LONG),
        ])
        .unwrap();
        let b = DataType::try_struct_type([
            StructField::nullable("o", long_array()),
            StructField::nullable("p", long_to_string_map()),
            StructField::nullable("q", b_q),
            StructField::nullable("r", DataType::LONG),
        ])
        .unwrap();

        // c: map<long, struct<f: long, g: string>>
        let c_value = DataType::try_struct_type([
            StructField::nullable("f", DataType::LONG),
            StructField::nullable("g", DataType::STRING),
        ])
        .unwrap();

        DataType::try_struct_type([
            StructField::nullable("a", ArrayType::new(a_element, true)),
            StructField::nullable("b", b),
            StructField::nullable("c", MapType::new(DataType::LONG, c_value, true)),
        ])
        .unwrap()
    }

    #[test]
    fn test_depth_checker() {
        let schema = nested_test_schema();
        // Each row is (depth_limit, expected (max_depth_seen, call_count)). Below the schema's
        // true depth of 7, the checker reports depth_limit + 1 and stops early, so fewer hooks run.
        let expectations = [
            (1, (2, 5)),
            (2, (3, 6)),
            (3, (4, 10)),
            (4, (5, 11)),
            (5, (6, 15)),
            (6, (7, 28)),
            (7, (7, 32)),
            (8, (7, 32)),
        ];
        for (depth_limit, expected) in expectations {
            assert_eq!(
                SchemaDepthChecker::check_with_call_count(&schema, depth_limit),
                expected,
                "depth_limit = {depth_limit}"
            );
        }
    }
}