use std::collections::{HashMap, HashSet};
use crate::utils::require;
use super::{DataType, StructField, StructType};
/// Newtype around a nullability flag (`true` means "may contain nulls"), so that nullability
/// can be compared with the same [`SchemaComparison`] machinery as fields and data types.
// NOTE(review): `#[allow(unused)]` presumably silences dead-code warnings in builds where this
// module's items are not referenced — confirm and document the reason.
#[allow(unused)]
#[derive(Clone, Copy)]
pub(crate) struct Nullable(bool);
/// Reasons why one schema (or field, or data type) cannot be read as another.
#[allow(unused)]
#[derive(Debug, thiserror::Error)]
pub(crate) enum Error {
    /// Something that was nullable in the existing schema is non-nullable in the read schema.
    #[error("The nullability was tightened for a field")]
    NullabilityTightening,
    /// Two fields matched case-insensitively, but their names are not exactly equal
    /// (for example they differ only in letter case).
    #[error("Field names do not match")]
    FieldNameMismatch,
    /// A struct contains two or more field names that collide when lowercased.
    #[error("Schema is invalid")]
    InvalidSchema,
    /// A column of the existing schema has no case-insensitive counterpart in the read schema.
    #[error("The read schema is missing a column present in the schema")]
    MissingColumn,
    /// The read schema adds a column that the existing schema lacks, and that column is
    /// non-nullable.
    #[error("Read schema has a non-nullable column that is not present in the schema")]
    NewNonNullableColumn,
    /// The data types are neither identical, nor structurally compatible, nor a permitted
    /// primitive widening.
    #[error("Types for two schema fields did not match")]
    TypeMismatch,
}
/// Outcome of a compatibility check: `Ok(())` when compatible, otherwise the first
/// incompatibility encountered.
#[allow(unused)]
pub(crate) type SchemaComparisonResult = Result<(), Error>;
/// Compatibility check between an existing type (`self`) and a type it would be read as.
#[allow(unused)]
pub(crate) trait SchemaComparison {
    /// Returns `Ok(())` if data described by `self` can be read using `read_type`, or the
    /// [`Error`] describing the first incompatibility found.
    fn can_read_as(&self, read_type: &Self) -> SchemaComparisonResult;
}
impl SchemaComparison for Nullable {
fn can_read_as(&self, read_nullable: &Nullable) -> SchemaComparisonResult {
require!(read_nullable.0 || !self.0, Error::NullabilityTightening);
Ok(())
}
}
impl SchemaComparison for StructField {
fn can_read_as(&self, read_field: &Self) -> SchemaComparisonResult {
Nullable(self.nullable).can_read_as(&Nullable(read_field.nullable))?;
require!(self.name() == read_field.name(), Error::FieldNameMismatch);
self.data_type().can_read_as(read_field.data_type())?;
Ok(())
}
}
impl SchemaComparison for StructType {
fn can_read_as(&self, read_type: &Self) -> SchemaComparisonResult {
let lowercase_field_map: HashMap<String, &StructField> = self
.fields
.iter()
.map(|(name, field)| (name.to_lowercase(), field))
.collect();
require!(
lowercase_field_map.len() == self.fields.len(),
Error::InvalidSchema
);
let lowercase_read_field_names: HashSet<String> =
read_type.fields.keys().map(|x| x.to_lowercase()).collect();
require!(
lowercase_read_field_names.len() == read_type.fields.len(),
Error::InvalidSchema
);
if lowercase_field_map
.keys()
.any(|name| !lowercase_read_field_names.contains(name))
{
return Err(Error::MissingColumn);
}
for read_field in read_type.fields() {
match lowercase_field_map.get(&read_field.name().to_lowercase()) {
Some(existing_field) => existing_field.can_read_as(read_field)?,
None => {
require!(read_field.is_nullable(), Error::NewNonNullableColumn);
}
}
}
Ok(())
}
}
impl SchemaComparison for DataType {
fn can_read_as(&self, read_type: &Self) -> SchemaComparisonResult {
match (self, read_type) {
(Self::Array(self_array), Self::Array(read_array)) => {
Nullable(self_array.contains_null())
.can_read_as(&Nullable(read_array.contains_null()))?;
self_array
.element_type()
.can_read_as(read_array.element_type())?;
}
(Self::Struct(self_struct), Self::Struct(read_struct)) => {
self_struct.can_read_as(read_struct)?
}
(Self::Map(self_map), Self::Map(read_map)) => {
Nullable(self_map.value_contains_null())
.can_read_as(&Nullable(read_map.value_contains_null()))?;
self_map.key_type().can_read_as(read_map.key_type())?;
self_map.value_type().can_read_as(read_map.value_type())?;
}
(a, b) if a == b => {}
(Self::Primitive(a), Self::Primitive(b)) if a.can_widen_to(b) => {}
_ => return Err(Error::TypeMismatch),
};
Ok(())
}
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::schema::compare::{Error, SchemaComparison};
    use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};

    /// Every schema, including nested map, array and struct types, can be read as itself.
    #[test]
    fn can_read_is_reflexive() {
        let map_key = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
        ]);
        let map_value =
            StructType::new_unchecked([StructField::new("age", DataType::INTEGER, true)]);
        let map_type = MapType::new(map_key, map_value, true);
        let array_type = ArrayType::new(DataType::TIMESTAMP, false);
        let nested_struct = StructType::new_unchecked([
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("map", map_type, false),
            StructField::new("array", array_type, false),
            StructField::new("nested_struct", nested_struct, false),
        ]);
        assert!(schema.can_read_as(&schema).is_ok());
    }

    /// Adding nullable columns to struct-typed map keys and values is allowed, and so is
    /// relaxing the nullability of an existing map-value column.
    #[test]
    fn add_nullable_column_to_map_key_and_value() {
        let existing_map_key = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, true),
        ]);
        let existing_map_value =
            StructType::new_unchecked([StructField::new("age", DataType::INTEGER, false)]);
        let existing_schema = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(existing_map_key, existing_map_value, false),
            false,
        )]);
        let read_map_key = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, true),
            StructField::new("location", DataType::STRING, true),
        ]);
        let read_map_value = StructType::new_unchecked([
            StructField::new("age", DataType::INTEGER, true),
            StructField::new("years_of_experience", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(read_map_key, read_map_value, false),
            false,
        )]);
        assert!(existing_schema.can_read_as(&read_schema).is_ok());
    }

    /// Tightening the nullability of a column inside a map value is rejected.
    #[test]
    fn map_value_becomes_non_nullable_fails() {
        let map_key = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
        ]);
        let map_value =
            StructType::new_unchecked([StructField::new("age", DataType::INTEGER, true)]);
        let existing_schema = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(map_key, map_value, false),
            false,
        )]);
        let map_key = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
        ]);
        let map_value =
            StructType::new_unchecked([StructField::new("age", DataType::INTEGER, false)]);
        let read_schema = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(map_key, map_value, false),
            false,
        )]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::NullabilityTightening)
        ));
    }

    /// Fields match case-insensitively, but a difference in case is still reported.
    #[test]
    fn different_field_name_case_fails() {
        let existing_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([
            StructField::new("Id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::FieldNameMismatch)
        ));
    }

    /// Narrowing a column's type (LONG to INTEGER) is rejected.
    #[test]
    fn different_type_fails() {
        let existing_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([
            StructField::new("id", DataType::INTEGER, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::TypeMismatch)
        ));
    }

    /// Relaxing a top-level column from non-nullable to nullable is allowed.
    #[test]
    fn set_nullable_to_true() {
        let existing_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, true),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        assert!(existing_schema.can_read_as(&read_schema).is_ok());
    }

    /// Tightening a top-level column from nullable to non-nullable is rejected.
    #[test]
    fn set_nullable_to_false_fails() {
        let existing_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, false),
        ]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::NullabilityTightening)
        ));
    }

    /// A new nullable column may be added, but dropping a column is rejected.
    #[test]
    fn differ_by_nullable_column() {
        let a = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let b = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
            StructField::new("location", DataType::STRING, true),
        ]);
        assert!(a.can_read_as(&b).is_ok());
        assert!(matches!(b.can_read_as(&a), Err(Error::MissingColumn)));
    }

    /// A new non-nullable column is rejected, and dropping a column is rejected.
    #[test]
    fn differ_by_non_nullable_column() {
        let a = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let b = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
            StructField::new("location", DataType::STRING, false),
        ]);
        assert!(matches!(
            a.can_read_as(&b),
            Err(Error::NewNonNullableColumn)
        ));
        assert!(matches!(b.can_read_as(&a), Err(Error::MissingColumn)));
    }

    /// Names that collide case-insensitively make a schema invalid, on either side.
    #[test]
    fn duplicate_field_modulo_case() {
        let existing_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("Id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        let read_schema = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("Id", DataType::LONG, false),
            StructField::new("name", DataType::STRING, false),
            StructField::new("age", DataType::INTEGER, true),
        ]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::InvalidSchema)
        ));
        assert!(matches!(
            read_schema.can_read_as(&existing_schema),
            Err(Error::InvalidSchema)
        ));
    }

    /// Integer types may widen (BYTE < SHORT < INTEGER < LONG) but never narrow.
    #[test]
    fn type_widening_integer() {
        assert!(DataType::BYTE.can_read_as(&DataType::SHORT).is_ok());
        assert!(DataType::BYTE.can_read_as(&DataType::INTEGER).is_ok());
        assert!(DataType::BYTE.can_read_as(&DataType::LONG).is_ok());
        assert!(DataType::SHORT.can_read_as(&DataType::INTEGER).is_ok());
        assert!(DataType::SHORT.can_read_as(&DataType::LONG).is_ok());
        assert!(DataType::INTEGER.can_read_as(&DataType::LONG).is_ok());
        assert!(matches!(
            DataType::LONG.can_read_as(&DataType::INTEGER),
            Err(Error::TypeMismatch)
        ));
        assert!(matches!(
            DataType::INTEGER.can_read_as(&DataType::SHORT),
            Err(Error::TypeMismatch)
        ));
        assert!(matches!(
            DataType::SHORT.can_read_as(&DataType::BYTE),
            Err(Error::TypeMismatch)
        ));
    }

    /// Exercises `PrimitiveType::is_stats_type_compatible_with`, which is not defined in this
    /// module.
    #[rstest]
    #[case::integer_to_date(PrimitiveType::Integer, PrimitiveType::Date, true)]
    #[case::long_to_timestamp(PrimitiveType::Long, PrimitiveType::Timestamp, true)]
    #[case::long_to_timestamp_ntz(PrimitiveType::Long, PrimitiveType::TimestampNtz, true)]
    #[case::date_to_integer(PrimitiveType::Date, PrimitiveType::Integer, false)]
    #[case::timestamp_to_long(PrimitiveType::Timestamp, PrimitiveType::Long, false)]
    #[case::timestamp_ntz_to_long(PrimitiveType::TimestampNtz, PrimitiveType::Long, false)]
    #[case::long_to_date(PrimitiveType::Long, PrimitiveType::Date, false)]
    #[case::integer_to_timestamp(PrimitiveType::Integer, PrimitiveType::Timestamp, false)]
    #[case::integer_to_timestamp_ntz(PrimitiveType::Integer, PrimitiveType::TimestampNtz, false)]
    #[case::byte_to_date(PrimitiveType::Byte, PrimitiveType::Date, false)]
    #[case::date_identity(PrimitiveType::Date, PrimitiveType::Date, true)]
    #[case::timestamp_identity(PrimitiveType::Timestamp, PrimitiveType::Timestamp, true)]
    #[case::long_identity(PrimitiveType::Long, PrimitiveType::Long, true)]
    #[case::byte_to_long(PrimitiveType::Byte, PrimitiveType::Long, true)]
    #[case::short_to_integer(PrimitiveType::Short, PrimitiveType::Integer, true)]
    #[case::float_to_double(PrimitiveType::Float, PrimitiveType::Double, true)]
    #[case::timestamp_to_ntz(PrimitiveType::Timestamp, PrimitiveType::TimestampNtz, true)]
    fn stats_type_compatibility(
        #[case] source: PrimitiveType,
        #[case] target: PrimitiveType,
        #[case] expected: bool,
    ) {
        assert_eq!(
            source.is_stats_type_compatible_with(&target),
            expected,
            "{source:?} -> {target:?} should be {expected}"
        );
    }

    /// FLOAT may widen to DOUBLE, but not the other way round.
    #[test]
    fn type_widening_float() {
        assert!(DataType::FLOAT.can_read_as(&DataType::DOUBLE).is_ok());
        assert!(matches!(
            DataType::DOUBLE.can_read_as(&DataType::FLOAT),
            Err(Error::TypeMismatch)
        ));
    }

    /// Widening applies to struct fields, and narrowing them is rejected.
    #[test]
    fn type_widening_in_struct() {
        let source = StructType::new_unchecked([
            StructField::new("id", DataType::INTEGER, false),
            StructField::new("value", DataType::FLOAT, true),
        ]);
        let target = StructType::new_unchecked([
            StructField::new("id", DataType::LONG, false),
            StructField::new("value", DataType::DOUBLE, true),
        ]);
        assert!(source.can_read_as(&target).is_ok());
        assert!(matches!(
            target.can_read_as(&source),
            Err(Error::TypeMismatch)
        ));
    }

    /// Unrelated primitive types are incompatible in both directions.
    #[test]
    fn incompatible_type_change() {
        assert!(matches!(
            DataType::STRING.can_read_as(&DataType::INTEGER),
            Err(Error::TypeMismatch)
        ));
        assert!(matches!(
            DataType::INTEGER.can_read_as(&DataType::STRING),
            Err(Error::TypeMismatch)
        ));
        assert!(matches!(
            DataType::BOOLEAN.can_read_as(&DataType::INTEGER),
            Err(Error::TypeMismatch)
        ));
    }

    /// Array element nullability may be relaxed but not tightened.
    #[test]
    fn array_element_nullability() {
        let nullable_elements = StructType::new_unchecked([StructField::new(
            "values",
            ArrayType::new(DataType::INTEGER, true),
            false,
        )]);
        let non_nullable_elements = StructType::new_unchecked([StructField::new(
            "values",
            ArrayType::new(DataType::INTEGER, false),
            false,
        )]);
        assert!(non_nullable_elements
            .can_read_as(&nullable_elements)
            .is_ok());
        assert!(matches!(
            nullable_elements.can_read_as(&non_nullable_elements),
            Err(Error::NullabilityTightening)
        ));
    }

    /// Array element types follow the same widening rules as top-level primitives.
    #[test]
    fn array_element_type_widening() {
        let narrow = StructType::new_unchecked([StructField::new(
            "values",
            ArrayType::new(DataType::INTEGER, false),
            false,
        )]);
        let wide = StructType::new_unchecked([StructField::new(
            "values",
            ArrayType::new(DataType::LONG, false),
            false,
        )]);
        assert!(narrow.can_read_as(&wide).is_ok());
        assert!(matches!(
            wide.can_read_as(&narrow),
            Err(Error::TypeMismatch)
        ));
    }

    /// Map keys are compared recursively, so an incompatible key type is rejected, while map
    /// values may widen.
    #[test]
    fn map_key_and_value_types() {
        let existing_schema = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(DataType::STRING, DataType::INTEGER, true),
            false,
        )]);
        let changed_key = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(DataType::INTEGER, DataType::INTEGER, true),
            false,
        )]);
        let widened_value = StructType::new_unchecked([StructField::new(
            "map",
            MapType::new(DataType::STRING, DataType::LONG, true),
            false,
        )]);
        assert!(matches!(
            existing_schema.can_read_as(&changed_key),
            Err(Error::TypeMismatch)
        ));
        assert!(existing_schema.can_read_as(&widened_value).is_ok());
    }

    /// Dropping a column inside a nested struct is rejected, but adding a nullable one there
    /// is allowed.
    #[test]
    fn nested_struct_missing_column() {
        let existing_schema = StructType::new_unchecked([StructField::new(
            "nested",
            StructType::new_unchecked([
                StructField::new("a", DataType::LONG, false),
                StructField::new("b", DataType::STRING, true),
            ]),
            false,
        )]);
        let read_schema = StructType::new_unchecked([StructField::new(
            "nested",
            StructType::new_unchecked([StructField::new("a", DataType::LONG, false)]),
            false,
        )]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::MissingColumn)
        ));
        assert!(read_schema.can_read_as(&existing_schema).is_ok());
    }

    /// Adding a non-nullable column inside a nested struct is rejected.
    #[test]
    fn nested_struct_new_non_nullable_column_fails() {
        let existing_schema = StructType::new_unchecked([StructField::new(
            "nested",
            StructType::new_unchecked([StructField::new("a", DataType::LONG, false)]),
            false,
        )]);
        let read_schema = StructType::new_unchecked([StructField::new(
            "nested",
            StructType::new_unchecked([
                StructField::new("a", DataType::LONG, false),
                StructField::new("c", DataType::STRING, false),
            ]),
            false,
        )]);
        assert!(matches!(
            existing_schema.can_read_as(&read_schema),
            Err(Error::NewNonNullableColumn)
        ));
    }
}