// datafusion-proto 53.0.0
//
// Protobuf serialization of DataFusion logical plan expressions
// Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};

use datafusion::execution::FunctionRegistry;
use datafusion::prelude::SessionContext;
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{ColumnarValue, col, create_udf, lit};
use datafusion_expr::{Expr, Volatility};
use datafusion_functions::string;
use datafusion_proto::bytes::Serializeable;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use datafusion_proto::logical_plan::to_proto::serialize_expr;

/// Decoding bytes that are not a valid protobuf message must fail with the
/// decode error named in `should_panic`.
#[test]
#[should_panic(
    expected = "Error decoding expr as protobuf: failed to decode Protobuf message"
)]
fn bad_decode() {
    let garbage: &[u8] = b"Leet";
    let decoded = Expr::from_bytes(garbage);
    // Unwrapping the error is what triggers the expected panic.
    decoded.unwrap();
}

/// An empty relation that produces no rows serializes to the minimal
/// `emptyRelation` JSON object.
#[test]
#[cfg(feature = "json")]
fn plan_to_json() {
    use datafusion_common::DFSchema;
    use datafusion_expr::{LogicalPlan, logical_plan::EmptyRelation};
    use datafusion_proto::bytes::logical_plan_to_json;

    let empty_relation = EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(DFSchema::empty()),
    };
    let plan = LogicalPlan::EmptyRelation(empty_relation);

    let actual = logical_plan_to_json(&plan).unwrap();
    let expected = String::from(r#"{"emptyRelation":{}}"#);
    assert_eq!(actual, expected);
}

/// The minimal `emptyRelation` JSON object parses back into a
/// `LogicalPlan::EmptyRelation`.
#[test]
#[cfg(feature = "json")]
fn json_to_plan() {
    use datafusion_expr::LogicalPlan;
    use datafusion_proto::bytes::logical_plan_from_json;

    let session = SessionContext::new();
    let json = String::from(r#"{"emptyRelation":{}}"#);

    let plan = logical_plan_from_json(&json, &session.task_ctx()).unwrap();
    assert!(
        matches!(plan, LogicalPlan::EmptyRelation(_)),
        "Should parse empty relation"
    );
}

/// A call to the registered `dummy` UDF survives a byte round-trip when the
/// decoder is given the context that holds the UDF.
#[test]
fn udf_roundtrip_with_registry() {
    let registry = context_with_udf();

    let dummy = registry.udf("dummy").expect("could not find udf");
    let original = dummy.call(vec![lit("")]);

    let encoded = original.to_bytes().unwrap();
    let decoded = Expr::from_bytes_with_registry(&encoded, &registry).unwrap();

    assert_eq!(original, decoded);
}

/// Decoding a call to the `dummy` UDF through plain `Expr::from_bytes` is
/// expected to panic with the message named in `should_panic`.
#[test]
#[should_panic(
    expected = "LogicalExtensionCodec is not provided for scalar function dummy"
)]
fn udf_roundtrip_without_registry() {
    let registry = context_with_udf();

    let dummy = registry.udf("dummy").expect("could not find udf");
    let encoded = dummy.call(vec![lit("")]).to_bytes().unwrap();

    // Decoded without the context that holds `dummy`; this unwrap is the
    // expected panic.
    Expr::from_bytes(&encoded).unwrap();
}

/// Serializes `expr` to bytes and decodes those bytes again with
/// `Expr::from_bytes`, panicking if either step fails.
fn roundtrip_expr(expr: &Expr) -> Expr {
    expr.to_bytes()
        .and_then(|encoded| Expr::from_bytes(&encoded))
        .unwrap()
}

/// A left-deep chain of `AND`s must round-trip to exactly the same tree, and
/// chains with reordered operands or different nesting must not become equal
/// to it after a round-trip.
#[test]
fn exact_roundtrip_linearized_binary_expr() {
    // (((A AND B) AND C) AND D)
    let reference = col("A").and(col("B")).and(col("C")).and(col("D"));
    assert_eq!(reference, roundtrip_expr(&reference));

    // Same operands, but in a different order or with different nesting.
    let permutations = [
        // (((B AND A) AND C) AND D)
        col("B").and(col("A")).and(col("C")).and(col("D")),
        // (((A AND C) AND B) AND D)
        col("A").and(col("C")).and(col("B")).and(col("D")),
        // (((A AND B) AND D) AND C)
        col("A").and(col("B")).and(col("D")).and(col("C")),
        // A AND (B AND (C AND D))
        col("A").and(col("B").and(col("C").and(col("D")))),
    ];

    for variant in &permutations {
        let decoded = roundtrip_expr(variant);

        // Each variant is still equal to itself
        assert_eq!(*variant, decoded);

        // But none of them is equal to the original
        assert_ne!(reference, decoded);
        assert_ne!(roundtrip_expr(&reference), decoded);
    }
}

/// A column aliased with a table qualifier keeps both the qualifier and the
/// alias name across a round-trip.
#[test]
fn roundtrip_qualified_alias() {
    let aliased = col("c1").alias_qualified(Some("my_table"), "my_column");
    let decoded = roundtrip_expr(&aliased);
    assert_eq!(aliased, decoded);
}

/// A placeholder whose field carries metadata must compare equal to itself
/// after a round-trip.
#[test]
fn roundtrip_placeholder_with_metadata() {
    let metadata = [("some_key".to_string(), "some_value".to_string())];
    let field = Field::new("", DataType::Utf8, false).with_metadata(metadata.into());
    let placeholder =
        Placeholder::new_with_field("placeholder_id".to_string(), Some(field.into()));

    let expr = Expr::Placeholder(placeholder);
    assert_eq!(expr, roundtrip_expr(&expr));
}

/// Round-trips a large expression built from a left-deep `OR` chain that is
/// itself repeated in a left-deep `AND` chain.
#[test]
fn roundtrip_deeply_nested_binary_expr() {
    let check = || {
        let depth = 100;

        // a < 5
        let comparison = col("a").lt(lit(5i32));

        // (((a < 5) OR (a < 5)) OR (a < 5)) OR ...
        let mut or_chain = comparison.clone();
        for _ in 0..depth {
            or_chain = or_chain.or(comparison.clone());
        }

        // ((or_chain AND or_chain) AND or_chain) AND ...
        let mut expr = or_chain.clone();
        for _ in 0..depth {
            expr = expr.and(or_chain.clone());
        }

        // Should work fine.
        let bytes = expr.to_bytes().unwrap();

        let decoded_expr = Expr::from_bytes(&bytes)
            .expect("serialization worked, so deserialization should work as well");
        assert_eq!(decoded_expr, expr);
    };

    // We need more stack space so this doesn't overflow in dev builds
    let handle = std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(check)
        .expect("spawning thread");
    handle.join().expect("joining thread");
}

/// Round-trips an expression whose deep `AND` chain sits on the right-hand
/// side of the top-level `AND`.
#[test]
fn roundtrip_deeply_nested_binary_expr_reverse_order() {
    let check = || {
        let depth = 100;

        // a < 5
        let comparison = col("a").lt(lit(5i32));

        // ((a < 5 AND a < 5) AND a < 5) AND ...
        let mut and_chain = comparison.clone();
        for _ in 0..depth {
            and_chain = and_chain.and(comparison.clone());
        }

        // a < 5 AND (((a < 5 AND a < 5) AND a < 5) AND ...)
        let expr = comparison.and(and_chain);

        // Should work fine.
        let bytes = expr.to_bytes().unwrap();

        let decoded_expr = Expr::from_bytes(&bytes)
            .expect("serialization worked, so deserialization should work as well");
        assert_eq!(decoded_expr, expr);
    };

    // We need more stack space so this doesn't overflow in dev builds
    let handle = std::thread::Builder::new()
        .stack_size(10_000_000)
        .spawn(check)
        .expect("spawning thread");
    handle.join().expect("joining thread");
}

/// Grows an alternating AND/OR chain one level at a time until serialization
/// reports an error, checking that every depth that does serialize also
/// round-trips. Panics if no depth below `n_max` is rejected.
#[test]
fn roundtrip_deeply_nested() {
    let probe = || {
        // don't know what "too much" is, so let's slowly try to increase complexity
        let n_max = 100;

        for n in 1..n_max {
            println!("testing: {n}");

            let expr_base = col("a").lt(lit(5i32));

            // Generate a tree of AND and OR expressions (no subsequent ANDs or ORs).
            let mut expr = expr_base.clone();
            for step in 0..n {
                let rhs = expr_base.clone();
                expr = if step % 2 == 0 {
                    expr.and(rhs)
                } else {
                    expr.or(rhs)
                };
            }

            // Convert it to an opaque form
            let Ok(bytes) = expr.to_bytes() else {
                // found expression that is too deeply nested
                return;
            };

            // Decode bytes from somewhere (over network, etc.)
            let decoded_expr = Expr::from_bytes(&bytes)
                .expect("serialization worked, so deserialization should work as well");
            assert_eq!(expr, decoded_expr);
        }

        panic!(
            "did not find a 'too deeply nested' expression, tested up to a depth of {n_max}"
        );
    };

    // we need more stack space so this doesn't overflow in dev builds
    let handle = std::thread::Builder::new()
        .stack_size(20_000_000)
        .spawn(probe)
        .expect("spawning thread");
    handle.join().expect("joining thread");
}

/// Builds a `SessionContext` in which a scalar UDF named `dummy` is registered.
///
/// `dummy` is declared as `Utf8 -> Utf8` with immutable volatility. Its
/// implementation returns its first argument re-wrapped as an array and
/// panics if that argument is not an array.
fn context_with_udf() -> SessionContext {
    let passthrough = Arc::new(|args: &[ColumnarValue]| match &args[0] {
        ColumnarValue::Array(array) => {
            Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef))
        }
        _ => panic!("should be array"),
    });

    let session = SessionContext::new();
    session.register_udf(create_udf(
        "dummy",
        vec![DataType::Utf8],
        DataType::Utf8,
        Volatility::Immutable,
        passthrough,
    ));

    session
}

/// For every function returned by `string::functions()`, a call with four null
/// `Utf8` arguments must keep the same function name after being serialized to
/// protobuf and parsed back.
#[test]
fn test_expression_serialization_roundtrip() {
    use datafusion_common::ScalarValue;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_proto::logical_plan::from_proto::parse_expr;

    /// Extracts the first part of a function name
    /// 'foo(bar)' -> 'foo'
    fn extract_function_name(expr: &Expr) -> String {
        let schema_name = expr.schema_name().to_string();
        // Everything before the first '(' — or the whole name if there is none.
        let prefix = schema_name
            .split_once('(')
            .map_or(schema_name.as_str(), |(head, _)| head);
        prefix.to_string()
    }

    // default to 4 args (though some exprs like substr have error checking)
    const NUM_ARGS: usize = 4;

    let registry = SessionContext::new();
    let codec = DefaultLogicalExtensionCodec {};
    let null_utf8 = Expr::Literal(ScalarValue::Utf8(None), None);

    for function in string::functions() {
        let args = vec![null_utf8.clone(); NUM_ARGS];
        let original = Expr::ScalarFunction(ScalarFunction::new_udf(function, args));

        let proto = serialize_expr(&original, &codec).unwrap();
        let decoded = parse_expr(&proto, &registry, &codec).unwrap();

        let serialize_name = extract_function_name(&original);
        let deserialize_name = extract_function_name(&decoded);

        assert_eq!(serialize_name, deserialize_name);
    }
}