datafusion_ffi/udaf/
accumulator_args.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{
22    std_types::{RString, RVec},
23    StableAbi,
24};
25use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema};
26use arrow_schema::FieldRef;
27use datafusion::{
28    error::DataFusionError,
29    logical_expr::function::AccumulatorArgs,
30    physical_expr::{PhysicalExpr, PhysicalSortExpr},
31    prelude::SessionContext,
32};
33use datafusion_proto::{
34    physical_plan::{
35        from_proto::{parse_physical_exprs, parse_physical_sort_exprs},
36        to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs},
37        DefaultPhysicalExtensionCodec,
38    },
39    protobuf::PhysicalAggregateExprNode,
40};
41use prost::Message;
42
43/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries.
44/// For an explanation of each field, see the corresponding field
45/// defined in [`AccumulatorArgs`].
46#[repr(C)]
47#[derive(Debug, StableAbi)]
48#[allow(non_camel_case_types)]
49pub struct FFI_AccumulatorArgs {
50    return_field: WrappedSchema,
51    schema: WrappedSchema,
52    is_reversed: bool,
53    name: RString,
54    physical_expr_def: RVec<u8>,
55}
56
57impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
58    type Error = DataFusionError;
59
60    fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> {
61        let return_field =
62            WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
63        let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
64
65        let codec = DefaultPhysicalExtensionCodec {};
66        let ordering_req =
67            serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?;
68
69        let expr = serialize_physical_exprs(args.exprs, &codec)?;
70
71        let physical_expr_def = PhysicalAggregateExprNode {
72            expr,
73            ordering_req,
74            distinct: args.is_distinct,
75            ignore_nulls: args.ignore_nulls,
76            fun_definition: None,
77            aggregate_function: None,
78            human_display: args.name.to_string(),
79        };
80        let physical_expr_def = physical_expr_def.encode_to_vec().into();
81
82        Ok(Self {
83            return_field,
84            schema,
85            is_reversed: args.is_reversed,
86            name: args.name.into(),
87            physical_expr_def,
88        })
89    }
90}
91
92/// This struct mirrors AccumulatorArgs except that it contains owned data.
93/// It is necessary to create this struct so that we can parse the protobuf
94/// data across the FFI boundary and turn it into owned data that
95/// AccumulatorArgs can then reference.
96pub struct ForeignAccumulatorArgs {
97    pub return_field: FieldRef,
98    pub schema: Schema,
99    pub ignore_nulls: bool,
100    pub order_bys: Vec<PhysicalSortExpr>,
101    pub is_reversed: bool,
102    pub name: String,
103    pub is_distinct: bool,
104    pub exprs: Vec<Arc<dyn PhysicalExpr>>,
105}
106
107impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
108    type Error = DataFusionError;
109
110    fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
111        let proto_def =
112            PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref())
113                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
114
115        let return_field = Arc::new((&value.return_field.0).try_into()?);
116        let schema = Schema::try_from(&value.schema.0)?;
117
118        let default_ctx = SessionContext::new();
119        let codex = DefaultPhysicalExtensionCodec {};
120
121        let order_bys = parse_physical_sort_exprs(
122            &proto_def.ordering_req,
123            &default_ctx,
124            &schema,
125            &codex,
126        )?;
127
128        let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?;
129
130        Ok(Self {
131            return_field,
132            schema,
133            ignore_nulls: proto_def.ignore_nulls,
134            order_bys,
135            is_reversed: value.is_reversed,
136            name: value.name.to_string(),
137            is_distinct: proto_def.distinct,
138            exprs,
139        })
140    }
141}
142
143impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
144    fn from(value: &'a ForeignAccumulatorArgs) -> Self {
145        Self {
146            return_field: Arc::clone(&value.return_field),
147            schema: &value.schema,
148            ignore_nulls: value.ignore_nulls,
149            order_bys: &value.order_bys,
150            is_reversed: value.is_reversed,
151            name: value.name.as_str(),
152            is_distinct: value.is_distinct,
153            exprs: &value.exprs,
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        error::Result, logical_expr::function::AccumulatorArgs,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
    };

    /// Sends a set of arguments through the FFI representation and back,
    /// then checks that nothing visible in the `Debug` output has changed.
    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let order_bys = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let exprs = [col("a", &schema)?];

        let original = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls: false,
            order_bys: &order_bys,
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &exprs,
        };
        // Capture the expected rendering before `original` is consumed.
        let expected = format!("{original:?}");

        let ffi_args = FFI_AccumulatorArgs::try_from(original)?;
        let foreign_args = ForeignAccumulatorArgs::try_from(ffi_args)?;
        let restored = AccumulatorArgs::from(&foreign_args);
        let actual = format!("{restored:?}");

        // AccumulatorArgs doesn't implement Eq, so the debug strings are
        // compared instead.
        assert_eq!(expected, actual);

        Ok(())
    }
}