use std::borrow::{Cow, ToOwned};
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};
use crate::transforms::{
map_owned_children_or_else, map_owned_or_else, map_owned_pair_or_else, transform_output_type,
Carrier,
};
/// Recursive visitor over schema types: [`DataType`] and its primitive, struct, array, map and
/// variant components.
///
/// Every `transform_*` hook has a default that either passes the node through untouched
/// (primitives) or descends into its children through the matching `recurse_into_*` helper.
/// Implementations override only the hooks they care about, and can call a `recurse_into_*`
/// helper from inside an override to keep the traversal going.
///
/// Results are expressed through [`Carrier`], so the concrete result shape is chosen by the
/// implementation.
///
/// NOTE(review): how borrowed vs. rebuilt values and residuals are combined is decided by the
/// `map_owned_*` helpers in `crate::transforms`, which are not visible from this file -- confirm
/// there before relying on any particular behavior.
pub trait SchemaTransform<'a> {
    /// Result wrapper produced by every hook for a node of type `T`.
    type Output<T: ToOwned + ?Sized + 'a>: Carrier<'a, T, Residual = Self::Residual>;

    /// Residual type shared by every [`Self::Output`] instantiation.
    type Residual;

    /// Hook for primitive types. By default the input is handed back as-is, wrapped in a
    /// borrowed [`Cow`].
    fn transform_primitive(&mut self, ptype: &'a PrimitiveType) -> Self::Output<PrimitiveType> {
        Carrier::from_inner(Cow::Borrowed(ptype))
    }

    /// Hook for struct types. By default it descends via [`Self::recurse_into_struct`].
    fn transform_struct(&mut self, stype: &'a StructType) -> Self::Output<StructType> {
        self.recurse_into_struct(stype)
    }

    /// Hook for a single struct field. By default it descends via
    /// [`Self::recurse_into_struct_field`].
    fn transform_struct_field(&mut self, field: &'a StructField) -> Self::Output<StructField> {
        self.recurse_into_struct_field(field)
    }

    /// Hook for array types. By default it descends via [`Self::recurse_into_array`].
    fn transform_array(&mut self, atype: &'a ArrayType) -> Self::Output<ArrayType> {
        self.recurse_into_array(atype)
    }

    /// Hook for the element type of an array. By default it forwards to [`Self::transform`].
    fn transform_array_element(&mut self, etype: &'a DataType) -> Self::Output<DataType> {
        self.transform(etype)
    }

    /// Hook for map types. By default it descends via [`Self::recurse_into_map`].
    fn transform_map(&mut self, mtype: &'a MapType) -> Self::Output<MapType> {
        self.recurse_into_map(mtype)
    }

    /// Hook for the key type of a map. By default it forwards to [`Self::transform`].
    fn transform_map_key(&mut self, etype: &'a DataType) -> Self::Output<DataType> {
        self.transform(etype)
    }

    /// Hook for the value type of a map. By default it forwards to [`Self::transform`].
    fn transform_map_value(&mut self, etype: &'a DataType) -> Self::Output<DataType> {
        self.transform(etype)
    }

    /// Hook for the struct carried by a `DataType::Variant`. By default it descends straight
    /// into the fields via [`Self::recurse_into_struct`]; [`Self::transform_struct`] is not
    /// involved.
    fn transform_variant(&mut self, stype: &'a StructType) -> Self::Output<StructType> {
        self.recurse_into_struct(stype)
    }

    /// Entry point for an arbitrary [`DataType`]: dispatches on the variant to the matching
    /// hook, then passes the hook's result to `map_owned_or_else` together with a function that
    /// wraps the inner type back into a [`DataType`].
    fn transform(&mut self, data_type: &'a DataType) -> Self::Output<DataType> {
        match data_type {
            DataType::Primitive(inner) => {
                map_owned_or_else(data_type, self.transform_primitive(inner), DataType::from)
            }
            DataType::Array(inner) => {
                map_owned_or_else(data_type, self.transform_array(inner), DataType::from)
            }
            DataType::Struct(inner) => {
                map_owned_or_else(data_type, self.transform_struct(inner), DataType::from)
            }
            DataType::Map(inner) => {
                map_owned_or_else(data_type, self.transform_map(inner), DataType::from)
            }
            DataType::Variant(inner) => {
                // Variants hold their struct in a `Box`, so they need an explicit re-wrap.
                map_owned_or_else(data_type, self.transform_variant(inner), |s| {
                    DataType::Variant(Box::new(s))
                })
            }
        }
    }

    /// Transforms the data type of `field`. The rebuild function given to `map_owned_or_else`
    /// creates a field with the new data type and the original name, nullability and metadata.
    fn recurse_into_struct_field(&mut self, field: &'a StructField) -> Self::Output<StructField> {
        let transformed_type = self.transform(&field.data_type);
        let rebuild = |data_type| StructField {
            name: field.name.clone(),
            data_type,
            nullable: field.nullable,
            metadata: field.metadata.clone(),
        };
        map_owned_or_else(field, transformed_type, rebuild)
    }

    /// Runs [`Self::transform_struct_field`] over the fields of `stype`. The per-field results
    /// are passed as a lazy iterator to `map_owned_children_or_else`, with
    /// `StructType::new_unchecked` as the rebuild function.
    fn recurse_into_struct(&mut self, stype: &'a StructType) -> Self::Output<StructType> {
        map_owned_children_or_else(
            stype,
            stype.fields().map(|field| self.transform_struct_field(field)),
            StructType::new_unchecked,
        )
    }

    /// Transforms the element type of `atype` via [`Self::transform_array_element`]. The rebuild
    /// function keeps the original type name and `contains_null` flag.
    fn recurse_into_array(&mut self, atype: &'a ArrayType) -> Self::Output<ArrayType> {
        let transformed_element = self.transform_array_element(&atype.element_type);
        let rebuild = |element_type| ArrayType {
            type_name: atype.type_name.clone(),
            element_type,
            contains_null: atype.contains_null,
        };
        map_owned_or_else(atype, transformed_element, rebuild)
    }

    /// Transforms the key type and then the value type of `mtype` (both hooks always run, in
    /// that order). The rebuild function keeps the original type name and
    /// `value_contains_null` flag.
    fn recurse_into_map(&mut self, mtype: &'a MapType) -> Self::Output<MapType> {
        let transformed_key = self.transform_map_key(&mtype.key_type);
        let transformed_value = self.transform_map_value(&mtype.value_type);
        map_owned_pair_or_else(mtype, transformed_key, transformed_value, |(key, value)| {
            MapType {
                type_name: mtype.type_name.clone(),
                key_type: key,
                value_type: value,
                value_contains_null: mtype.value_contains_null,
            }
        })
    }
}
/// Bookkeeping for a depth-limited walk over a schema.
///
/// One field holds the configured limit; the other three are counters that the walk updates as
/// it goes.
pub struct SchemaDepthChecker {
    /// Deepest nesting level the walk is allowed to descend below.
    depth_limit: usize,
    /// Nesting level of the node currently being visited.
    current_depth: usize,
    /// Largest `current_depth` observed so far.
    max_depth_seen: usize,
    /// Number of depth-limited visits performed so far.
    call_count: usize,
}
impl SchemaDepthChecker {
    /// Walks `data_type` and reports the deepest nesting level that was visited.
    ///
    /// The depth is recorded before it is compared against `depth_limit`, so when the schema is
    /// deeper than the limit the returned value is larger than `depth_limit`.
    pub fn check(data_type: &DataType, depth_limit: usize) -> usize {
        let (deepest, _visits) = Self::check_with_call_count(data_type, depth_limit);
        deepest
    }

    /// Same walk as [`Self::check`], returning `(max_depth_seen, call_count)` so the number of
    /// depth-limited visits can be inspected as well.
    fn check_with_call_count(data_type: &DataType, depth_limit: usize) -> (usize, usize) {
        let mut walker = Self {
            depth_limit,
            current_depth: 0,
            max_depth_seen: 0,
            call_count: 0,
        };
        // Only the counters matter here; whether the walk finished or bailed out is ignored.
        let _ = walker.transform(data_type);
        (walker.max_depth_seen, walker.call_count)
    }

    /// Runs `recurse(self, arg)` one level deeper, unless the current depth is already past the
    /// limit.
    ///
    /// Each call is counted and the current depth is folded into `max_depth_seen` first. If the
    /// current depth exceeds `depth_limit`, a warning naming `arg` is logged and `Err(())` is
    /// returned without invoking `recurse`. Otherwise the depth is incremented for the duration
    /// of `recurse` and its result is returned unchanged.
    fn depth_limited<'a, T: ToOwned + std::fmt::Debug + ?Sized>(
        &mut self,
        recurse: impl FnOnce(&mut Self, &'a T) -> Result<(), ()>,
        arg: &'a T,
    ) -> Result<(), ()> {
        self.call_count += 1;
        if self.current_depth > self.max_depth_seen {
            self.max_depth_seen = self.current_depth;
        }

        if self.current_depth <= self.depth_limit {
            self.current_depth += 1;
            let outcome = recurse(self, arg);
            // Restore the depth on the way back up, whatever the nested walk returned.
            self.current_depth -= 1;
            outcome
        } else {
            tracing::warn!("Max schema depth {} exceeded by {arg:?}", self.depth_limit);
            Err(())
        }
    }
}
/// Depth-tracking traversal: struct, struct-field, array and map visits are routed through
/// `SchemaDepthChecker::depth_limited`, each wrapping the corresponding default `recurse_into_*`
/// step. All other hooks keep their trait defaults.
impl<'a> SchemaTransform<'a> for SchemaDepthChecker {
    transform_output_type!(|'a, T| Result<(), ()>);

    /// Visits a struct one nesting level deeper, subject to the depth limit.
    fn transform_struct(&mut self, stype: &'a StructType) -> Result<(), ()> {
        self.depth_limited(|checker, node| checker.recurse_into_struct(node), stype)
    }

    /// Visits a struct field one nesting level deeper, subject to the depth limit.
    fn transform_struct_field(&mut self, field: &'a StructField) -> Result<(), ()> {
        self.depth_limited(|checker, node| checker.recurse_into_struct_field(node), field)
    }

    /// Visits an array one nesting level deeper, subject to the depth limit.
    fn transform_array(&mut self, atype: &'a ArrayType) -> Result<(), ()> {
        self.depth_limited(|checker, node| checker.recurse_into_array(node), atype)
    }

    /// Visits a map one nesting level deeper, subject to the depth limit.
    fn transform_map(&mut self, mtype: &'a MapType) -> Result<(), ()> {
        self.depth_limited(|checker, node| checker.recurse_into_map(node), mtype)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::{DataType, StructField};
// Exercises `SchemaDepthChecker` against a fixed nested schema at increasing depth limits,
// checking both the deepest level reached and the number of depth-limited visits.
#[test]
fn test_depth_checker() {
// Fixture: a top-level struct with three fields.
//   a: array of struct { w: long, x: array<long>, y: map<long, string>, z: struct { n, m } }
//   b: struct { o: array<long>, p: map<long, string>, q: struct { s: struct { u, v }, t }, r }
//   c: map<long, struct { f, g }>
// The expected counts below depend on this exact shape and on the field order.
let schema = DataType::try_struct_type([
StructField::nullable(
"a",
ArrayType::new(
DataType::try_struct_type([
StructField::nullable("w", DataType::LONG),
StructField::nullable("x", ArrayType::new(DataType::LONG, true)),
StructField::nullable(
"y",
MapType::new(DataType::LONG, DataType::STRING, true),
),
StructField::nullable(
"z",
DataType::try_struct_type([
StructField::nullable("n", DataType::LONG),
StructField::nullable("m", DataType::STRING),
])
.unwrap(),
),
])
.unwrap(),
true,
),
),
StructField::nullable(
"b",
DataType::try_struct_type([
StructField::nullable("o", ArrayType::new(DataType::LONG, true)),
StructField::nullable(
"p",
MapType::new(DataType::LONG, DataType::STRING, true),
),
StructField::nullable(
"q",
DataType::try_struct_type([
StructField::nullable(
"s",
DataType::try_struct_type([
// `u` and `v` are the deepest depth-counted nodes: schema struct -> b ->
// struct -> q -> struct -> s -> struct -> field, i.e. depth 7.
StructField::nullable("u", DataType::LONG),
StructField::nullable("v", DataType::LONG),
])
.unwrap(),
),
StructField::nullable("t", DataType::LONG),
])
.unwrap(),
),
StructField::nullable("r", DataType::LONG),
])
.unwrap(),
),
StructField::nullable(
"c",
MapType::new(
DataType::LONG,
DataType::try_struct_type([
StructField::nullable("f", DataType::LONG),
StructField::nullable("g", DataType::STRING),
])
.unwrap(),
true,
),
),
])
.unwrap();
// Each result is `(max_depth_seen, call_count)`.
let check_with_call_count =
|depth_limit| SchemaDepthChecker::check_with_call_count(&schema, depth_limit);
// Limits 1 through 6 are exceeded: the checker records the depth before rejecting it, so
// the reported maximum is `depth_limit + 1`, and the call count grows with how much of the
// schema was visited before the walk stopped.
assert_eq!(check_with_call_count(1), (2, 3));
assert_eq!(check_with_call_count(2), (3, 4));
assert_eq!(check_with_call_count(3), (4, 5));
assert_eq!(check_with_call_count(4), (5, 7));
assert_eq!(check_with_call_count(5), (6, 12));
assert_eq!(check_with_call_count(6), (7, 24));
// From limit 7 upward the whole schema is walked: depth 7 is reached in 32 depth-limited
// visits, and raising the limit further changes nothing.
assert_eq!(check_with_call_count(7), (7, 32));
assert_eq!(check_with_call_count(8), (7, 32));
}
}