use crate::logical_plan::{
self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
};
use crate::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use crate::protobuf;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{
create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility,
WindowUDF,
};
use prost::{
bytes::{Bytes, BytesMut},
Message,
};
use std::sync::Arc;
use datafusion::execution::registry::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
mod registry;
pub trait Serializeable: Sized {
fn to_bytes(&self) -> Result<Bytes>;
fn from_bytes(bytes: &[u8]) -> Result<Self> {
Self::from_bytes_with_registry(bytes, ®istry::NoRegistry {})
}
fn from_bytes_with_registry(
bytes: &[u8],
registry: &dyn FunctionRegistry,
) -> Result<Self>;
}
impl Serializeable for Expr {
    /// Encodes this expression as a protobuf `LogicalExprNode` and returns
    /// the encoded bytes.
    ///
    /// Before returning, the bytes are decoded once against a stand-in
    /// registry and any decode error is propagated, so a successful result
    /// has already survived one decode pass.
    ///
    /// # Errors
    /// Returns `DataFusionError::Plan` if the expression cannot be converted
    /// to protobuf or the protobuf cannot be encoded, and forwards any error
    /// from the validation decode.
    fn to_bytes(&self) -> Result<Bytes> {
        use std::collections::HashSet;

        /// Registry used only for the validation decode: every lookup
        /// succeeds and yields a stub function carrying the requested name.
        /// The stub bodies are `unimplemented!()`, so they must never be
        /// evaluated.
        struct PlaceHolderRegistry;

        impl FunctionRegistry for PlaceHolderRegistry {
            fn udfs(&self) -> HashSet<String> {
                HashSet::default()
            }

            fn udf(&self, name: &str) -> Result<Arc<datafusion_expr::ScalarUDF>> {
                let stub = create_udf(
                    name,
                    vec![],
                    Arc::new(arrow::datatypes::DataType::Null),
                    Volatility::Immutable,
                    make_scalar_function(|_| unimplemented!()),
                );
                Ok(Arc::new(stub))
            }

            fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
                let stub = create_udaf(
                    name,
                    vec![arrow::datatypes::DataType::Null],
                    Arc::new(arrow::datatypes::DataType::Null),
                    Volatility::Immutable,
                    Arc::new(|_| unimplemented!()),
                    Arc::new(vec![]),
                );
                Ok(Arc::new(stub))
            }

            fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
                let stub = create_udwf(
                    name,
                    arrow::datatypes::DataType::Null,
                    Arc::new(arrow::datatypes::DataType::Null),
                    Volatility::Immutable,
                    Arc::new(|| unimplemented!()),
                );
                Ok(Arc::new(stub))
            }
        }

        // Expr -> protobuf message.
        let node: protobuf::LogicalExprNode = self.try_into().map_err(|e| {
            DataFusionError::Plan(format!("Error encoding expr as protobuf: {e}"))
        })?;

        // Protobuf message -> bytes.
        let mut encoded = BytesMut::new();
        node.encode(&mut encoded).map_err(|e| {
            DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}"))
        })?;
        let bytes: Bytes = encoded.into();

        // Validation pass: fail now if these bytes cannot be decoded again.
        Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?;

        Ok(bytes)
    }

    /// Decodes `bytes` as a protobuf `LogicalExprNode` and parses it into an
    /// [`Expr`], looking up referenced functions in `registry`.
    ///
    /// # Errors
    /// Returns `DataFusionError::Plan` if the bytes are not a valid protobuf
    /// message or the message cannot be parsed into an expression.
    fn from_bytes_with_registry(
        bytes: &[u8],
        registry: &dyn FunctionRegistry,
    ) -> Result<Self> {
        let node = protobuf::LogicalExprNode::decode(bytes).map_err(|e| {
            DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}"))
        })?;

        match logical_plan::from_proto::parse_expr(&node, registry) {
            Ok(expr) => Ok(expr),
            Err(e) => Err(DataFusionError::Plan(format!(
                "Error parsing protobuf into Expr: {e}"
            ))),
        }
    }
}
/// Serializes a [`LogicalPlan`] to bytes using the default logical
/// extension codec.
///
/// # Errors
/// Forwards any error from [`logical_plan_to_bytes_with_extension_codec`].
pub fn logical_plan_to_bytes(plan: &LogicalPlan) -> Result<Bytes> {
    logical_plan_to_bytes_with_extension_codec(plan, &DefaultLogicalExtensionCodec {})
}
/// Serializes a [`LogicalPlan`] to a JSON string using the default logical
/// extension codec.
///
/// # Errors
/// Returns `DataFusionError::Plan` if the plan cannot be converted to
/// protobuf or the protobuf cannot be rendered as JSON.
#[cfg(feature = "json")]
pub fn logical_plan_to_json(plan: &LogicalPlan) -> Result<String> {
    protobuf::LogicalPlanNode::try_from_logical_plan(
        plan,
        &DefaultLogicalExtensionCodec {},
    )
    .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))
    .and_then(|node| {
        serde_json::to_string(&node)
            .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))
    })
}
pub fn logical_plan_to_bytes_with_extension_codec(
plan: &LogicalPlan,
extension_codec: &dyn LogicalExtensionCodec,
) -> Result<Bytes> {
let protobuf =
protobuf::LogicalPlanNode::try_from_logical_plan(plan, extension_codec)?;
let mut buffer = BytesMut::new();
protobuf.encode(&mut buffer).map_err(|e| {
DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}"))
})?;
Ok(buffer.into())
}
/// Deserializes a [`LogicalPlan`] from a JSON string using the default
/// logical extension codec.
///
/// `ctx` is passed through to the protobuf-to-plan conversion.
///
/// # Errors
/// Returns `DataFusionError::Plan` if `json` cannot be parsed into a
/// protobuf `LogicalPlanNode`, and forwards any error from converting that
/// node into a plan.
#[cfg(feature = "json")]
pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result<LogicalPlan> {
    // This is the deserialization path, so the message says so (it previously
    // claimed a serialization failure).
    let back: protobuf::LogicalPlanNode = serde_json::from_str(json)
        .map_err(|e| DataFusionError::Plan(format!("Error deserializing plan: {e}")))?;
    let extension_codec = DefaultLogicalExtensionCodec {};
    back.try_into_logical_plan(ctx, &extension_codec)
}
/// Deserializes a [`LogicalPlan`] from bytes using the default logical
/// extension codec.
///
/// # Errors
/// Forwards any error from [`logical_plan_from_bytes_with_extension_codec`].
pub fn logical_plan_from_bytes(
    bytes: &[u8],
    ctx: &SessionContext,
) -> Result<LogicalPlan> {
    logical_plan_from_bytes_with_extension_codec(
        bytes,
        ctx,
        &DefaultLogicalExtensionCodec {},
    )
}
pub fn logical_plan_from_bytes_with_extension_codec(
bytes: &[u8],
ctx: &SessionContext,
extension_codec: &dyn LogicalExtensionCodec,
) -> Result<LogicalPlan> {
let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}"))
})?;
protobuf.try_into_logical_plan(ctx, extension_codec)
}
/// Serializes a physical plan to bytes using the default physical
/// extension codec.
///
/// # Errors
/// Forwards any error from [`physical_plan_to_bytes_with_extension_codec`].
pub fn physical_plan_to_bytes(plan: Arc<dyn ExecutionPlan>) -> Result<Bytes> {
    physical_plan_to_bytes_with_extension_codec(plan, &DefaultPhysicalExtensionCodec {})
}
/// Serializes a physical plan to a JSON string using the default physical
/// extension codec.
///
/// # Errors
/// Returns `DataFusionError::Plan` if the plan cannot be converted to
/// protobuf or the protobuf cannot be rendered as JSON.
#[cfg(feature = "json")]
pub fn physical_plan_to_json(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
    protobuf::PhysicalPlanNode::try_from_physical_plan(
        plan,
        &DefaultPhysicalExtensionCodec {},
    )
    .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))
    .and_then(|node| {
        serde_json::to_string(&node)
            .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))
    })
}
pub fn physical_plan_to_bytes_with_extension_codec(
plan: Arc<dyn ExecutionPlan>,
extension_codec: &dyn PhysicalExtensionCodec,
) -> Result<Bytes> {
let protobuf =
protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?;
let mut buffer = BytesMut::new();
protobuf.encode(&mut buffer).map_err(|e| {
DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}"))
})?;
Ok(buffer.into())
}
/// Deserializes a physical plan from a JSON string using the default
/// physical extension codec.
///
/// `ctx` and its runtime environment are passed through to the
/// protobuf-to-plan conversion.
///
/// # Errors
/// Returns `DataFusionError::Plan` if `json` cannot be parsed into a
/// protobuf `PhysicalPlanNode`, and forwards any error from converting that
/// node into a plan.
#[cfg(feature = "json")]
pub fn physical_plan_from_json(
    json: &str,
    ctx: &SessionContext,
) -> Result<Arc<dyn ExecutionPlan>> {
    // This is the deserialization path, so the message says so (it previously
    // claimed a serialization failure).
    let back: protobuf::PhysicalPlanNode = serde_json::from_str(json)
        .map_err(|e| DataFusionError::Plan(format!("Error deserializing plan: {e}")))?;
    let extension_codec = DefaultPhysicalExtensionCodec {};
    back.try_into_physical_plan(ctx, &ctx.runtime_env(), &extension_codec)
}
/// Deserializes a physical plan from bytes using the default physical
/// extension codec.
///
/// # Errors
/// Forwards any error from [`physical_plan_from_bytes_with_extension_codec`].
pub fn physical_plan_from_bytes(
    bytes: &[u8],
    ctx: &SessionContext,
) -> Result<Arc<dyn ExecutionPlan>> {
    physical_plan_from_bytes_with_extension_codec(
        bytes,
        ctx,
        &DefaultPhysicalExtensionCodec {},
    )
}
pub fn physical_plan_from_bytes_with_extension_codec(
bytes: &[u8],
ctx: &SessionContext,
extension_codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn ExecutionPlan>> {
let protobuf = protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| {
DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}"))
})?;
protobuf.try_into_physical_plan(ctx, &ctx.runtime_env(), extension_codec)
}
#[cfg(test)]
mod test {
use super::*;
use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, create_udf, lit, Volatility};
use std::sync::Arc;
/// Decoding bytes that are not a protobuf message must fail with the
/// decode error message.
#[test]
#[should_panic(
    expected = "Error decoding expr as protobuf: failed to decode Protobuf message"
)]
fn bad_decode() {
    let garbage: &[u8] = b"Leet";
    // The unwrap is the point of the test: the panic text carries the error.
    let _ = Expr::from_bytes(garbage).unwrap();
}
/// An empty relation serializes to the expected JSON document.
#[test]
#[cfg(feature = "json")]
fn plan_to_json() {
    use datafusion_common::DFSchema;
    use datafusion_expr::logical_plan::EmptyRelation;

    let empty_relation = EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    };
    let json = logical_plan_to_json(&LogicalPlan::EmptyRelation(empty_relation)).unwrap();

    assert_eq!(json, r#"{"emptyRelation":{}}"#.to_string());
}
/// The JSON document for an empty relation parses back into
/// `LogicalPlan::EmptyRelation`.
#[test]
#[cfg(feature = "json")]
fn json_to_plan() {
    let ctx = SessionContext::new();
    let plan = logical_plan_from_json(r#"{"emptyRelation":{}}"#, &ctx).unwrap();

    assert!(
        matches!(plan, LogicalPlan::EmptyRelation(_)),
        "Should parse empty relation"
    );
}
/// An expression calling a registered UDF round-trips when the session
/// context is supplied as the function registry.
#[test]
fn udf_roundtrip_with_registry() {
    let ctx = context_with_udf();

    let dummy = ctx.udf("dummy").expect("could not find udf");
    let expr = dummy.call(vec![lit("")]);

    let encoded = expr.to_bytes().unwrap();
    let decoded = Expr::from_bytes_with_registry(&encoded, &ctx).unwrap();

    assert_eq!(expr, decoded);
}
/// Decoding an expression that calls a UDF through `from_bytes` (no
/// registry) must fail with the "no function registry" message.
#[test]
#[should_panic(
    expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'"
)]
fn udf_roundtrip_without_registry() {
    let ctx = context_with_udf();

    let dummy = ctx.udf("dummy").expect("could not find udf");
    let expr = dummy.call(vec![lit("")]);

    let encoded = expr.to_bytes().unwrap();
    // Expected to panic here: the error text is what `should_panic` checks.
    let _ = Expr::from_bytes(&encoded).unwrap();
}
/// Test helper: encodes `expr` to bytes and decodes it again, panicking if
/// either step fails.
fn roundtrip_expr(expr: &Expr) -> Expr {
    Expr::from_bytes(&expr.to_bytes().unwrap()).unwrap()
}
/// Round-tripping must preserve the exact shape of chained `AND`
/// expressions: each variant equals its own round-trip and never equals the
/// reference ordering or its round-trip.
#[test]
fn exact_roundtrip_linearized_binary_expr() {
    let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D"));
    assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered));

    // Same operands as the reference, but reordered or re-associated.
    let other_variants = vec![
        col("B").and(col("A")).and(col("C")).and(col("D")),
        col("A").and(col("C")).and(col("B")).and(col("D")),
        col("A").and(col("B")).and(col("D")).and(col("C")),
        col("A").and(col("B").and(col("C").and(col("D")))),
    ];

    for variant in &other_variants {
        assert_eq!(*variant, roundtrip_expr(variant));
        assert_ne!(expr_ordered, roundtrip_expr(variant));
        assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(variant));
    }
}
/// Round-trips an `AND` chain whose operands are themselves long `OR`
/// chains, on a thread with an explicitly enlarged stack.
#[test]
fn roundtrip_deeply_nested_binary_expr() {
    let worker = std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(|| {
            let n = 100;
            let basic_expr = col("a").lt(lit(5i32));

            // basic OR basic OR ... (n extra operands)
            let mut or_chain = basic_expr.clone();
            for _ in 0..n {
                or_chain = or_chain.or(basic_expr.clone());
            }

            // or_chain AND or_chain AND ... (n extra operands)
            let mut expr = or_chain.clone();
            for _ in 0..n {
                expr = expr.and(or_chain.clone());
            }

            let bytes = expr.to_bytes().unwrap();
            let decoded_expr = Expr::from_bytes(&bytes).expect(
                "serialization worked, so deserialization should work as well",
            );
            assert_eq!(decoded_expr, expr);
        })
        .expect("spawning thread");

    worker.join().expect("joining thread");
}
/// Round-trips `base AND (long AND chain)`, i.e. the long chain sits on the
/// right-hand side, on a thread with an explicitly enlarged stack.
#[test]
fn roundtrip_deeply_nested_binary_expr_reverse_order() {
    let worker = std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(|| {
            let n = 100;
            let expr_base = col("a").lt(lit(5i32));

            // base AND base AND ... (n extra operands)
            let mut and_chain = expr_base.clone();
            for _ in 0..n {
                and_chain = and_chain.and(expr_base.clone());
            }

            // Put the whole chain on the right of one more AND.
            let expr = expr_base.and(and_chain);

            let bytes = expr.to_bytes().unwrap();
            let decoded_expr = Expr::from_bytes(&bytes).expect(
                "serialization worked, so deserialization should work as well",
            );
            assert_eq!(decoded_expr, expr);
        })
        .expect("spawning thread");

    worker.join().expect("joining thread");
}
/// Builds alternating `AND`/`OR` chains of growing depth until `to_bytes`
/// rejects one. Every depth that does encode must also decode to an equal
/// expression. The test fails if no depth below `n_max` is rejected.
#[test]
fn roundtrip_deeply_nested() {
    let worker = std::thread::Builder::new()
        .stack_size(20_000_000)
        .spawn(|| {
            let n_max = 100;

            for n in 1..n_max {
                println!("testing: {n}");

                let expr_base = col("a").lt(lit(5i32));

                // Alternate AND (even steps) and OR (odd steps).
                let mut expr = expr_base.clone();
                for step in 0..n {
                    expr = if step % 2 == 0 {
                        expr.and(expr_base.clone())
                    } else {
                        expr.or(expr_base.clone())
                    };
                }

                // The first encode failure ends the test successfully.
                let bytes = match expr.to_bytes() {
                    Err(_) => return,
                    Ok(bytes) => bytes,
                };

                let decoded_expr = Expr::from_bytes(&bytes).expect(
                    "serialization worked, so deserialization should work as well",
                );
                assert_eq!(expr, decoded_expr);
            }

            panic!("did not find a 'too deeply nested' expression, tested up to a depth of {n_max}")
        })
        .expect("spawning thread");

    worker.join().expect("joining thread");
}
/// Test helper: returns a fresh [`SessionContext`] with a scalar UDF named
/// `"dummy"` registered. The UDF is declared as `Utf8 -> Utf8` and hands its
/// first argument back.
fn context_with_udf() -> SessionContext {
    let passthrough = make_scalar_function(|args: &[ArrayRef]| {
        Ok(Arc::new(args[0].clone()) as ArrayRef)
    });

    let ctx = SessionContext::new();
    ctx.register_udf(create_udf(
        "dummy",
        vec![DataType::Utf8],
        Arc::new(DataType::Utf8),
        Volatility::Immutable,
        passthrough,
    ));
    ctx
}
}