use std::borrow::Cow;
use std::ops::Deref as _;
use tracing::debug;
use crate::expressions::{Expression, Scalar};
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};
use crate::transforms::SchemaTransform;
use crate::DeltaResult;
/// Schema visitor that turns a depth-first sequence of scalars into one literal expression
/// shaped like the visited schema.
///
/// Every leaf field (primitive, array or map) consumes one scalar from `scalars`; every struct
/// assembles the expressions produced by its fields. Intermediate results live on `stack`, and
/// only the first failure encountered is kept in `error`.
struct LiteralExpressionTransform<'a, T>
where
    T: Iterator<Item = &'a Scalar>,
{
    /// Remaining input scalars, consumed in the order leaf fields are visited.
    scalars: T,
    /// Expressions built so far; a finished struct replaces its fields' entries with one entry.
    stack: Vec<Expression>,
    /// `Ok(())` until the first error is recorded; later errors are not stored.
    error: Result<(), Error>,
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Schema error: {0}")]
Schema(String),
#[error("Excess scalar: {0} given for literal expression transform")]
ExcessScalars(Scalar),
#[error("Too few scalars given for literal expression transform")]
InsufficientScalars,
#[error("No Expression was created after performing the transform")]
EmptyStack,
#[error("Unsupported operation: {0}")]
Unsupported(String),
}
pub(crate) fn literal_expression_transform<'a>(
schema: &'a StructType,
scalars: impl IntoIterator<Item = &'a Scalar>,
) -> DeltaResult<Expression> {
let mut transform = LiteralExpressionTransform {
scalars: scalars.into_iter(),
stack: Vec::new(),
error: Ok(()),
};
let _ = transform.transform_struct(schema);
transform.error?;
if let Some(s) = transform.scalars.next() {
return Err(Error::ExcessScalars(s.clone()).into());
}
transform.stack.pop().ok_or(Error::EmptyStack.into())
}
impl<'a, I> LiteralExpressionTransform<'a, I>
where
    I: Iterator<Item = &'a Scalar>,
{
    /// Records `error` if no error has been recorded yet.
    ///
    /// The first error always wins; any later error is only emitted as a debug log.
    fn set_error(&mut self, error: Error) {
        match self.error {
            Ok(()) => self.error = Err(error),
            Err(ref existing_error) => {
                debug!("Trying to overwrite an existing error: {existing_error:?} with {error:?}");
            }
        }
    }
}
// Shared body of the leaf transforms (primitive, array, map).
//
// `$self` is the transform, `$type_variant` is the `DataType` variant that the next scalar's data
// type must be (e.g. `DataType::Primitive`), and `$type` is the expected inner type taken from the
// schema. The expansion consumes at most one scalar and either pushes it onto the stack as a
// literal expression or records an error via `set_error`. Because the expansion is spliced into
// the calling method, the `?` and `return None` below return from that method; on success the
// block evaluates to `None` as well.
macro_rules! transform_leaf {
    ($self:ident, $type_variant:path, $type:ident) => {{
        // Stop consuming scalars once an error has been recorded.
        $self.error.as_ref().ok()?;
        let Some(scalar) = $self.scalars.next() else {
            $self.set_error(Error::InsufficientScalars);
            return None;
        };
        // The scalar must be the same kind of type (variant) as the schema leaf...
        let $type_variant(ref scalar_type) = scalar.data_type() else {
            $self.set_error(Error::Schema(format!(
                "Mismatched scalar type while creating Expression: expected {}({:?}), got {:?}",
                stringify!($type_variant),
                $type,
                scalar.data_type()
            )));
            return None;
        };
        // ...and its inner type must equal the schema's. `deref` unwraps the boxed inner type for
        // arrays and maps; for primitives it just yields the same reference.
        if scalar_type.deref() != $type {
            $self.set_error(Error::Schema(format!(
                "Mismatched scalar type while creating Expression: expected {:?}, got {:?}",
                $type, scalar_type
            )));
            return None;
        }
        $self.stack.push(Expression::Literal(scalar.clone()));
        None
    }};
}
impl<'a, T: Iterator<Item = &'a Scalar>> SchemaTransform<'a> for LiteralExpressionTransform<'a, T> {
    /// Consumes one scalar for a primitive leaf and pushes it as a literal expression.
    fn transform_primitive(
        &mut self,
        prim_type: &'a PrimitiveType,
    ) -> Option<Cow<'a, PrimitiveType>> {
        transform_leaf!(self, DataType::Primitive, prim_type)
    }

    /// Replaces the expressions pushed by this struct's fields with a single struct expression.
    ///
    /// If every field expression is NULL and at least one of those fields is non-nullable, the
    /// struct itself becomes a NULL literal of `struct_type`. A NULL in a non-nullable field with
    /// non-NULL siblings is a schema error. Otherwise a struct of the field expressions is
    /// pushed. Always returns `None`; results are communicated through the stack and error.
    fn transform_struct(&mut self, struct_type: &'a StructType) -> Option<Cow<'a, StructType>> {
        // Stop consuming scalars once an error has been recorded.
        self.error.as_ref().ok()?;

        // Only the expressions the fields push onto the stack matter; the value returned by the
        // recursion is not otherwise used.
        let mark = self.stack.len();
        self.recurse_into_struct(struct_type)?;
        let field_exprs = self.stack.split_off(mark);

        // A descendant may have recorded an error during the recursion. Bail out now so that
        // error is reported as-is, rather than trying to record a second (misleading)
        // `InsufficientScalars` that `set_error` would discard with a debug log at every
        // enclosing struct level.
        self.error.as_ref().ok()?;

        let fields = struct_type.fields();
        if field_exprs.len() != fields.len() {
            self.set_error(Error::InsufficientScalars);
            return None;
        }

        // Determine whether all field expressions are NULL, and whether any NULL sits in a
        // non-nullable field.
        let mut found_non_nullable_null = false;
        let mut all_null = true;
        for (field, expr) in fields.zip(&field_exprs) {
            if !matches!(expr, Expression::Literal(Scalar::Null(_))) {
                all_null = false;
            } else if !field.is_nullable() {
                found_non_nullable_null = true;
            }
        }

        let struct_expr = if found_non_nullable_null {
            if !all_null {
                // A non-nullable field is NULL while some sibling is not: invalid input.
                self.set_error(Error::Schema(
                    "NULL value for non-nullable struct field with non-NULL siblings".to_string(),
                ));
                return None;
            }
            // All fields NULL and at least one is non-nullable: the struct itself is NULL.
            Expression::null_literal(struct_type.clone().into())
        } else {
            Expression::struct_from(field_exprs)
        };

        self.stack.push(struct_expr);
        None
    }

    /// Recurses into the field's data type and always reports the field as unchanged, so the
    /// enclosing struct keeps visiting its remaining fields (which early-out on a recorded error).
    fn transform_struct_field(&mut self, field: &'a StructField) -> Option<Cow<'a, StructField>> {
        self.error.as_ref().ok()?;
        self.recurse_into_struct_field(field);
        Some(Cow::Borrowed(field))
    }

    /// Arrays are leaves: consumes one scalar of a matching array type.
    fn transform_array(&mut self, array_type: &'a ArrayType) -> Option<Cow<'a, ArrayType>> {
        transform_leaf!(self, DataType::Array, array_type)
    }

    /// Maps are leaves: consumes one scalar of a matching map type.
    fn transform_map(&mut self, map_type: &'a MapType) -> Option<Cow<'a, MapType>> {
        transform_leaf!(self, DataType::Map, map_type)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::expressions::{ArrayData, MapData};
    use crate::schema::SchemaRef;
    use crate::schema::StructType;
    use crate::DataType as DeltaDataTypes;

    use paste::paste;
    use Expression as Expr;

    /// Runs the transform over `values` and `schema`, asserting either the exact expected
    /// expression (`Ok`) or that the transform failed (`Err`).
    fn assert_single_row_transform(
        values: &[Scalar],
        schema: SchemaRef,
        expected: Result<Expr, ()>,
    ) {
        let transformed = literal_expression_transform(&schema, values);
        match expected {
            Ok(expected_expr) => assert_eq!(expected_expr, transformed.unwrap()),
            Err(()) => assert!(transformed.is_err()),
        }
    }

    #[test]
    fn test_create_one_top_level_null() {
        let values = &[Scalar::Null(DeltaDataTypes::INTEGER)];

        // A NULL in the only (non-nullable) column makes the whole row NULL.
        let schema = Arc::new(StructType::new_unchecked([StructField::not_null(
            "col_1",
            DeltaDataTypes::INTEGER,
        )]));
        let expected = Expr::null_literal(schema.clone().into());
        assert_single_row_transform(values, schema, Ok(expected));

        // A NULL in a nullable column stays a NULL field inside a non-NULL row.
        let schema = Arc::new(StructType::new_unchecked([StructField::nullable(
            "col_1",
            DeltaDataTypes::INTEGER,
        )]));
        let expected = Expr::struct_from(vec![Expr::null_literal(DeltaDataTypes::INTEGER)]);
        assert_single_row_transform(values, schema, Ok(expected));
    }

    #[test]
    fn test_create_one_missing_values() {
        let values = &[1.into()];
        let schema = Arc::new(StructType::new_unchecked([
            StructField::nullable("col_1", DeltaDataTypes::INTEGER),
            StructField::nullable("col_2", DeltaDataTypes::INTEGER),
        ]));
        assert_single_row_transform(values, schema, Err(()));
    }

    #[test]
    fn test_create_one_extra_values() {
        let values = &[1.into(), 2.into(), 3.into()];
        let schema = Arc::new(StructType::new_unchecked([
            StructField::nullable("col_1", DeltaDataTypes::INTEGER),
            StructField::nullable("col_2", DeltaDataTypes::INTEGER),
        ]));
        assert_single_row_transform(values, schema, Err(()));
    }

    #[test]
    fn test_create_one_incorrect_schema() {
        let values = &["a".into()];
        let schema = Arc::new(StructType::new_unchecked([StructField::nullable(
            "col_1",
            DeltaDataTypes::INTEGER,
        )]));
        assert_single_row_transform(values, schema, Err(()));
    }

    #[test]
    fn test_many_structs() {
        // Scalars are consumed depth-first: x.a, x.b, y.c, y.d.
        let values: &[Scalar] = &[1.into(), 2.into(), 3.into(), 4.into()];
        let schema = Arc::new(StructType::new_unchecked([
            StructField::nullable(
                "x",
                DeltaDataTypes::struct_type_unchecked([
                    StructField::not_null("a", DeltaDataTypes::INTEGER),
                    StructField::nullable("b", DeltaDataTypes::INTEGER),
                ]),
            ),
            StructField::nullable(
                "y",
                DeltaDataTypes::struct_type_unchecked([
                    StructField::not_null("c", DeltaDataTypes::INTEGER),
                    StructField::nullable("d", DeltaDataTypes::INTEGER),
                ]),
            ),
        ]));
        let expected = Expr::struct_from(vec![
            Expr::struct_from(vec![Expr::literal(1), Expr::literal(2)]),
            Expr::struct_from(vec![Expr::literal(3), Expr::literal(4)]),
        ]);
        assert_single_row_transform(values, schema, Ok(expected));
    }

    #[test]
    fn test_map_and_array() {
        let map_type = MapType::new(DeltaDataTypes::STRING, DeltaDataTypes::STRING, false);
        let map_data = MapData::try_new(map_type.clone(), vec![("k1", "v1")]).unwrap();
        let array_type = ArrayType::new(DeltaDataTypes::INTEGER, false);
        let array_data = ArrayData::try_new(array_type.clone(), vec![1, 2]).unwrap();
        let values: &[Scalar] = &[
            Scalar::Map(map_data.clone()),
            Scalar::Array(array_data.clone()),
        ];
        let schema = Arc::new(StructType::new_unchecked([
            StructField::nullable("map", DeltaDataTypes::Map(Box::new(map_type))),
            StructField::nullable("array", DeltaDataTypes::Array(Box::new(array_type))),
        ]));
        let expected = Expr::struct_from(vec![
            Expr::literal(Scalar::Map(map_data)),
            Expr::literal(Scalar::Array(array_data)),
        ]);
        assert_single_row_transform(values, schema, Ok(expected));
    }

    /// Nullability of the nested schema `{ x: { a: INTEGER, b: INTEGER } }` used by the
    /// nullability-combination tests.
    #[derive(Clone, Copy)]
    struct TestSchema {
        x_nullable: bool,
        a_nullable: bool,
        b_nullable: bool,
    }

    /// Expected outcome of a nullability-combination test.
    enum Expected {
        /// `{ x: { a, b } }` wrapping the input values unchanged.
        Noop,
        /// The outer struct with `x` replaced by a NULL literal of `x`'s struct type.
        NullStruct,
        /// The whole row is a NULL literal of the top-level schema type.
        Null,
        /// The transform fails.
        Error,
    }

    /// Converts an optional test value into an INTEGER scalar, using a typed NULL for `None`.
    fn int_or_null(value: Option<i32>) -> Scalar {
        match value {
            Some(v) => Scalar::Integer(v),
            None => Scalar::Null(DeltaDataTypes::INTEGER),
        }
    }

    /// Builds `{ x: { a: INTEGER, b: INTEGER } }` with the given nullability, transforms the
    /// values `(a, b)` against it and checks the outcome against `expected`.
    fn run_test(test_schema: TestSchema, values: (Option<i32>, Option<i32>), expected: Expected) {
        let (a_val, b_val) = values;
        let values: &[Scalar] = &[int_or_null(a_val), int_or_null(b_val)];

        let field_a = StructField::new("a", DeltaDataTypes::INTEGER, test_schema.a_nullable);
        let field_b = StructField::new("b", DeltaDataTypes::INTEGER, test_schema.b_nullable);
        let field_x = StructField::new(
            "x",
            StructType::new_unchecked([field_a, field_b]),
            test_schema.x_nullable,
        );
        let schema = Arc::new(StructType::new_unchecked([field_x.clone()]));

        let expected_result = match expected {
            Expected::Noop => {
                let nested_struct = Expr::struct_from(vec![
                    Expr::literal(values[0].clone()),
                    Expr::literal(values[1].clone()),
                ]);
                Ok(Expr::struct_from([nested_struct]))
            }
            Expected::Null => Ok(Expr::null_literal(schema.clone().into())),
            Expected::NullStruct => {
                let nested_null = Expr::null_literal(field_x.data_type().clone());
                Ok(Expr::struct_from([nested_null]))
            }
            Expected::Error => Err(()),
        };
        assert_single_row_transform(values, schema, expected_result);
    }

    /// Maps the `nullable` / `not_null` schema keywords to a nullability flag.
    macro_rules! bool_from_nullable {
        (nullable) => {
            true
        };
        (not_null) => {
            false
        };
    }

    /// Maps test value tokens to values: `a` -> 1, `b` -> 2, `N` -> NULL.
    macro_rules! parse_value {
        (a) => {
            Some(1)
        };
        (b) => {
            Some(2)
        };
        (N) => {
            None
        };
    }

    /// Generates one `#[test]` per `(a, b) -> Expected` case, named
    /// `<name>_<a>_<b>` (lowercased), all sharing the same schema nullability.
    macro_rules! test_nullability_combinations {
        (
            name = $name:ident,
            schema = { x: $x:ident, a: $a:ident, b: $b:ident },
            tests = {
                $( ($ta:tt, $tb:tt) -> $expected:ident ),+ $(,)?
            }
        ) => {
            paste! {
                $(
                    #[test]
                    fn [<$name _ $ta:lower _ $tb:lower>]() {
                        let schema = TestSchema {
                            x_nullable: bool_from_nullable!($x),
                            a_nullable: bool_from_nullable!($a),
                            b_nullable: bool_from_nullable!($b),
                        };
                        run_test(
                            schema,
                            (parse_value!($ta), parse_value!($tb)),
                            Expected::$expected,
                        );
                    }
                )+
            }
        };
    }

    test_nullability_combinations! {
        name = test_all_nullable,
        schema = { x: nullable, a: nullable, b: nullable },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Noop,
            (a, N) -> Noop,
            (N, N) -> Noop,
        }
    }

    test_nullability_combinations! {
        name = test_nullable_nullable_not_null,
        schema = { x: nullable, a: nullable, b: not_null },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Noop,
            (a, N) -> Error,
            (N, N) -> NullStruct,
        }
    }

    test_nullability_combinations! {
        name = test_nullable_not_null_not_null,
        schema = { x: nullable, a: not_null, b: not_null },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Error,
            (a, N) -> Error,
            (N, N) -> NullStruct,
        }
    }

    test_nullability_combinations! {
        name = test_not_null_nullable_nullable,
        schema = { x: not_null, a: nullable, b: nullable },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Noop,
            (a, N) -> Noop,
            (N, N) -> Noop,
        }
    }

    test_nullability_combinations! {
        name = test_not_null_nullable_not_null,
        schema = { x: not_null, a: nullable, b: not_null },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Noop,
            (a, N) -> Error,
            (N, N) -> Null,
        }
    }

    test_nullability_combinations! {
        name = test_all_not_null,
        schema = { x: not_null, a: not_null, b: not_null },
        tests = {
            (a, b) -> Noop,
            (N, b) -> Error,
            (a, N) -> Error,
            (N, N) -> Null,
        }
    }
}