datafusion_ffi/udaf/
accumulator_args.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{
22    std_types::{RString, RVec},
23    StableAbi,
24};
25use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema};
26use arrow_schema::FieldRef;
27use datafusion::{
28    error::DataFusionError,
29    logical_expr::function::AccumulatorArgs,
30    physical_expr::{PhysicalExpr, PhysicalSortExpr},
31    prelude::SessionContext,
32};
33use datafusion_common::exec_datafusion_err;
34use datafusion_proto::{
35    physical_plan::{
36        from_proto::{parse_physical_exprs, parse_physical_sort_exprs},
37        to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs},
38        DefaultPhysicalExtensionCodec,
39    },
40    protobuf::PhysicalAggregateExprNode,
41};
42use prost::Message;
43
44/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries.
45/// For an explanation of each field, see the corresponding field
46/// defined in [`AccumulatorArgs`].
47#[repr(C)]
48#[derive(Debug, StableAbi)]
49#[allow(non_camel_case_types)]
50pub struct FFI_AccumulatorArgs {
51    return_field: WrappedSchema,
52    schema: WrappedSchema,
53    is_reversed: bool,
54    name: RString,
55    physical_expr_def: RVec<u8>,
56}
57
58impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
59    type Error = DataFusionError;
60
61    fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> {
62        let return_field =
63            WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
64        let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
65
66        let codec = DefaultPhysicalExtensionCodec {};
67        let ordering_req =
68            serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?;
69
70        let expr = serialize_physical_exprs(args.exprs, &codec)?;
71
72        let physical_expr_def = PhysicalAggregateExprNode {
73            expr,
74            ordering_req,
75            distinct: args.is_distinct,
76            ignore_nulls: args.ignore_nulls,
77            fun_definition: None,
78            aggregate_function: None,
79            human_display: args.name.to_string(),
80        };
81        let physical_expr_def = physical_expr_def.encode_to_vec().into();
82
83        Ok(Self {
84            return_field,
85            schema,
86            is_reversed: args.is_reversed,
87            name: args.name.into(),
88            physical_expr_def,
89        })
90    }
91}
92
93/// This struct mirrors AccumulatorArgs except that it contains owned data.
94/// It is necessary to create this struct so that we can parse the protobuf
95/// data across the FFI boundary and turn it into owned data that
96/// AccumulatorArgs can then reference.
97pub struct ForeignAccumulatorArgs {
98    pub return_field: FieldRef,
99    pub schema: Schema,
100    pub expr_fields: Vec<FieldRef>,
101    pub ignore_nulls: bool,
102    pub order_bys: Vec<PhysicalSortExpr>,
103    pub is_reversed: bool,
104    pub name: String,
105    pub is_distinct: bool,
106    pub exprs: Vec<Arc<dyn PhysicalExpr>>,
107}
108
109impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
110    type Error = DataFusionError;
111
112    fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
113        let proto_def = PhysicalAggregateExprNode::decode(
114            value.physical_expr_def.as_ref(),
115        )
116        .map_err(|e| {
117            exec_datafusion_err!("Failed to decode PhysicalAggregateExprNode: {e}")
118        })?;
119
120        let return_field = Arc::new((&value.return_field.0).try_into()?);
121        let schema = Schema::try_from(&value.schema.0)?;
122
123        let default_ctx = SessionContext::new();
124        let task_ctx = default_ctx.task_ctx();
125        let codex = DefaultPhysicalExtensionCodec {};
126
127        let order_bys = parse_physical_sort_exprs(
128            &proto_def.ordering_req,
129            &task_ctx,
130            &schema,
131            &codex,
132        )?;
133
134        let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;
135
136        let expr_fields = exprs
137            .iter()
138            .map(|e| e.return_field(&schema))
139            .collect::<Result<Vec<_>, _>>()?;
140
141        Ok(Self {
142            return_field,
143            schema,
144            expr_fields,
145            ignore_nulls: proto_def.ignore_nulls,
146            order_bys,
147            is_reversed: value.is_reversed,
148            name: value.name.to_string(),
149            is_distinct: proto_def.distinct,
150            exprs,
151        })
152    }
153}
154
155impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
156    fn from(value: &'a ForeignAccumulatorArgs) -> Self {
157        Self {
158            return_field: Arc::clone(&value.return_field),
159            schema: &value.schema,
160            expr_fields: &value.expr_fields,
161            ignore_nulls: value.ignore_nulls,
162            order_bys: &value.order_bys,
163            is_reversed: value.is_reversed,
164            name: value.name.as_str(),
165            is_distinct: value.is_distinct,
166            exprs: &value.exprs,
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        error::Result, logical_expr::function::AccumulatorArgs,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
    };

    /// Sends `AccumulatorArgs` through the FFI representation and back,
    /// then checks that the result prints the same as the input.
    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        // Owned fixtures that the borrowed argument struct points into.
        let input_fields = [std::sync::Arc::new(Field::new("a", DataType::Int32, true))];
        let sort_exprs = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let input_exprs = [col("a", &schema)?];

        let original = AccumulatorArgs {
            return_field: std::sync::Arc::new(Field::new("f", DataType::Float64, true)),
            schema: &schema,
            expr_fields: &input_fields,
            ignore_nulls: false,
            order_bys: &sort_exprs,
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &input_exprs,
        };
        // Capture the debug rendering before `original` is consumed.
        let expected = format!("{original:?}");

        let ffi = FFI_AccumulatorArgs::try_from(original)?;
        let foreign = ForeignAccumulatorArgs::try_from(ffi)?;
        let restored = AccumulatorArgs::from(&foreign);
        let actual = format!("{restored:?}");

        // AccumulatorArgs does not implement Eq, so the debug strings
        // stand in for a structural comparison.
        assert_eq!(expected, actual);

        Ok(())
    }
}