use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use datafusion::execution::FunctionRegistry;
use datafusion::prelude::SessionContext;
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{ColumnarValue, col, create_udf, lit};
use datafusion_expr::{Expr, Volatility};
use datafusion_functions::string;
use datafusion_proto::bytes::Serializeable;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use datafusion_proto::logical_plan::to_proto::serialize_expr;
/// Feeding bytes that are not a protobuf-encoded expression to
/// `Expr::from_bytes` must fail with a protobuf decoding error.
#[test]
#[should_panic(
    expected = "Error decoding expr as protobuf: failed to decode Protobuf message"
)]
fn bad_decode() {
    let garbage: &[u8] = b"Leet";
    let _ = Expr::from_bytes(garbage).unwrap();
}
/// An `EmptyRelation` plan (no rows, empty schema) serializes to the minimal
/// `emptyRelation` JSON object.
#[test]
#[cfg(feature = "json")]
fn plan_to_json() {
    use datafusion_common::DFSchema;
    use datafusion_expr::{LogicalPlan, logical_plan::EmptyRelation};
    use datafusion_proto::bytes::logical_plan_to_json;

    let empty = EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    };
    let json = logical_plan_to_json(&LogicalPlan::EmptyRelation(empty)).unwrap();

    assert_eq!(json, r#"{"emptyRelation":{}}"#);
}
/// The minimal `emptyRelation` JSON object parses back into a
/// `LogicalPlan::EmptyRelation`.
#[test]
#[cfg(feature = "json")]
fn json_to_plan() {
    use datafusion_expr::LogicalPlan;
    use datafusion_proto::bytes::logical_plan_from_json;

    let ctx = SessionContext::new();
    let plan = logical_plan_from_json(r#"{"emptyRelation":{}}"#, &ctx.task_ctx()).unwrap();

    assert!(
        matches!(plan, LogicalPlan::EmptyRelation(_)),
        "Should parse empty relation"
    );
}
/// A call to the `dummy` UDF survives an encode/decode cycle when the decoder
/// is handed the session context (which has `dummy` registered) as its
/// function registry.
#[test]
fn udf_roundtrip_with_registry() {
    let ctx = context_with_udf();
    let dummy = ctx.udf("dummy").expect("could not find udf");
    let original = dummy.call(vec![lit("")]);

    let encoded = original.to_bytes().unwrap();
    let decoded = Expr::from_bytes_with_registry(&encoded, &ctx).unwrap();

    assert_eq!(original, decoded);
}
/// Encoding a call to the `dummy` UDF succeeds, but decoding it with plain
/// `Expr::from_bytes` (no registry supplied) must fail with the error named
/// in `should_panic`.
#[test]
#[should_panic(
    expected = "LogicalExtensionCodec is not provided for scalar function dummy"
)]
fn udf_roundtrip_without_registry() {
    let ctx = context_with_udf();
    let dummy = ctx.udf("dummy").expect("could not find udf");
    let encoded = dummy.call(vec![lit("")]).to_bytes().unwrap();

    let _ = Expr::from_bytes(&encoded).unwrap();
}
/// Encodes `expr` to protobuf bytes and decodes it again with
/// `Expr::from_bytes` (no function registry).
///
/// # Panics
/// Panics if either encoding or decoding returns an error.
fn roundtrip_expr(expr: &Expr) -> Expr {
    expr.to_bytes()
        .and_then(|bytes| Expr::from_bytes(&bytes))
        .unwrap()
}
/// Chains of the same binary operator must round-trip to *exactly* the same
/// tree shape and operand order.
///
/// The left-deep chain `((A AND B) AND C) AND D` must survive a round-trip
/// unchanged. Every variant in `other_variants` reorders the operands or
/// nests them differently; each variant must also round-trip to itself and
/// must never compare equal to the ordered chain, either before or after
/// its round-trip.
#[test]
fn exact_roundtrip_linearized_binary_expr() {
    let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D"));
    // Loop-invariant: round-trip the reference expression once.
    let ordered_roundtrip = roundtrip_expr(&expr_ordered);
    assert_eq!(expr_ordered, ordered_roundtrip);

    let other_variants = [
        col("B").and(col("A")).and(col("C")).and(col("D")),
        col("A").and(col("C")).and(col("B")).and(col("D")),
        col("A").and(col("B")).and(col("D")).and(col("C")),
        // Right-deep nesting: A AND (B AND (C AND D)).
        col("A").and(col("B").and(col("C").and(col("D")))),
    ];

    for case in other_variants {
        // Round-trip each variant once and reuse the result for all checks.
        let case_roundtrip = roundtrip_expr(&case);
        assert_eq!(case, case_roundtrip);
        assert_ne!(expr_ordered, case_roundtrip);
        assert_ne!(ordered_roundtrip, case_roundtrip);
    }
}
/// An alias with a table qualifier (`my_table.my_column`) keeps its
/// qualifier across a round-trip.
#[test]
fn roundtrip_qualified_alias() {
    let aliased = col("c1").alias_qualified(Some("my_table"), "my_column");
    let decoded = roundtrip_expr(&aliased);
    assert_eq!(aliased, decoded);
}
/// A placeholder whose field carries key/value metadata keeps that metadata
/// across a round-trip.
#[test]
fn roundtrip_placeholder_with_metadata() {
    let field = Field::new("", DataType::Utf8, false).with_metadata(
        [("some_key".to_string(), "some_value".to_string())].into(),
    );
    let placeholder =
        Placeholder::new_with_field("placeholder_id".to_string(), Some(field.into()));
    let expr = Expr::Placeholder(placeholder);

    assert_eq!(expr, roundtrip_expr(&expr));
}
/// Round-trips a large AND/OR tree.
///
/// The tree is an AND-chain of 101 copies of an OR-chain, and each OR-chain
/// has 101 `a < 5` terms. The check runs on a thread with a 10_000_000-byte
/// stack. Presumably the default stack is too small for encoding or decoding
/// a tree this deep; confirm if the stack size is ever tuned.
#[test]
fn roundtrip_deeply_nested_binary_expr() {
    fn check() {
        let depth = 100;
        let leaf = col("a").lt(lit(5i32));

        let mut or_chain = leaf.clone();
        for _ in 0..depth {
            or_chain = or_chain.or(leaf.clone());
        }

        let mut expr = or_chain.clone();
        for _ in 0..depth {
            expr = expr.and(or_chain.clone());
        }

        let bytes = expr.to_bytes().unwrap();
        let decoded_expr = Expr::from_bytes(&bytes)
            .expect("serialization worked, so deserialization should work as well");
        assert_eq!(decoded_expr, expr);
    }

    std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(check)
        .expect("spawning thread")
        .join()
        .expect("joining thread");
}
/// Round-trips `leaf AND chain`, where `chain` is a left-deep AND-chain of
/// 101 `a < 5` leaves.
///
/// The deep subtree therefore sits on the *right* operand of the root. The
/// check runs on a thread with a 10_000_000-byte stack, presumably for the
/// same depth reasons as the non-reversed test.
#[test]
fn roundtrip_deeply_nested_binary_expr_reverse_order() {
    fn check() {
        let leaf = col("a").lt(lit(5i32));

        let mut chain = leaf.clone();
        for _ in 0..100 {
            chain = chain.and(leaf.clone());
        }
        let expr = leaf.and(chain);

        let bytes = expr.to_bytes().unwrap();
        let decoded_expr = Expr::from_bytes(&bytes)
            .expect("serialization worked, so deserialization should work as well");
        assert_eq!(decoded_expr, expr);
    }

    std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(check)
        .expect("spawning thread")
        .join()
        .expect("joining thread");
}
/// Searches for an alternating AND/OR chain that is too deeply nested for
/// `Expr::to_bytes`, round-tripping every shallower chain along the way.
///
/// The exact nesting limit is not known here, so depth is increased one
/// operator at a time.
/// - The test passes as soon as `to_bytes` returns an error. Note that *any*
///   serialization error counts as "too deeply nested".
/// - The test panics if no error occurs for depths up to and including
///   `n_max`.
#[test]
fn roundtrip_deeply_nested() {
    // Extra stack space so deep nesting does not overflow the thread stack
    // (presumably relevant in unoptimized dev builds).
    std::thread::Builder::new()
        .stack_size(20_000_000)
        .spawn(|| {
            let n_max = 100;

            // Inclusive range, so the panic message below ("tested up to a
            // depth of {n_max}") is accurate.
            for n in 1..=n_max {
                println!("testing: {n}");

                let expr_base = col("a").lt(lit(5i32));
                // Alternate AND and OR so no two consecutive operators are the
                // same. The fold index is named `i` to avoid shadowing `n`.
                let expr = (0..n).fold(expr_base.clone(), |expr, i| {
                    if i % 2 == 0 {
                        expr.and(expr_base.clone())
                    } else {
                        expr.or(expr_base.clone())
                    }
                });

                let Ok(bytes) = expr.to_bytes() else {
                    // Found an expression that is too deeply nested to serialize.
                    return;
                };

                let decoded_expr = Expr::from_bytes(&bytes).expect(
                    "serialization worked, so deserialization should work as well",
                );
                assert_eq!(expr, decoded_expr);
            }

            panic!(
                "did not find a 'too deeply nested' expression, tested up to a depth of {n_max}"
            )
        })
        .expect("spawning thread")
        .join()
        .expect("joining thread");
}
/// Builds a fresh `SessionContext` with one scalar UDF registered under the
/// name `dummy`.
///
/// The UDF takes a single `Utf8` argument, returns `Utf8`, and is marked
/// `Volatility::Immutable`. Its implementation returns its array argument
/// unchanged.
///
/// # Panics
/// The UDF implementation panics with "should be array" if its first
/// argument is a scalar rather than an array.
fn context_with_udf() -> SessionContext {
    let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
        let ColumnarValue::Array(array) = &args[0] else {
            panic!("should be array")
        };
        // `Arc::clone` shares the input array with a refcount bump. Wrapping
        // the existing `ArrayRef` in another `Arc` would only add an
        // allocation and an extra level of indirection.
        let output: ArrayRef = Arc::clone(array);
        Ok(ColumnarValue::Array(output))
    });

    let udf = create_udf(
        "dummy",
        vec![DataType::Utf8],
        DataType::Utf8,
        Volatility::Immutable,
        scalar_fn,
    );

    let ctx = SessionContext::new();
    ctx.register_udf(udf);
    ctx
}
/// Every function from `string::functions()` keeps its name across a proto
/// round-trip through `serialize_expr` and `parse_expr`.
///
/// Each function is called with four `Utf8` NULL literals. Only the name is
/// compared, meaning the part of `schema_name()` before the first `(`.
#[test]
fn test_expression_serialization_roundtrip() {
    use datafusion_common::ScalarValue;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_proto::logical_plan::from_proto::parse_expr;

    /// Returns the prefix of `expr`'s schema name before the first `(`, or
    /// the whole name if it contains no `(`.
    fn extract_function_name(expr: &Expr) -> String {
        let name = expr.schema_name().to_string();
        match name.split_once('(') {
            Some((head, _)) => head.to_string(),
            None => name,
        }
    }

    let ctx = SessionContext::new();
    let extension_codec = DefaultLogicalExtensionCodec {};
    // Named `null_utf8` rather than `lit` so it does not shadow the imported
    // `lit` helper function.
    let null_utf8 = Expr::Literal(ScalarValue::Utf8(None), None);
    let num_args = 4;

    for function in string::functions() {
        let args = vec![null_utf8.clone(); num_args];
        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(function, args));

        let proto = serialize_expr(&expr, &extension_codec).unwrap();
        let deserialized = parse_expr(&proto, &ctx, &extension_codec).unwrap();

        let serialize_name = extract_function_name(&expr);
        let deserialize_name = extract_function_name(&deserialized);
        assert_eq!(
            serialize_name, deserialize_name,
            "function name changed across serialization round-trip"
        );
    }
}