datafusion_ffi/udaf/
accumulator_args.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use abi_stable::StableAbi;
21use abi_stable::std_types::{RString, RVec};
22use arrow::datatypes::Schema;
23use arrow::ffi::FFI_ArrowSchema;
24use arrow_schema::FieldRef;
25use datafusion_common::error::DataFusionError;
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
28
29use crate::arrow_wrappers::WrappedSchema;
30use crate::physical_expr::FFI_PhysicalExpr;
31use crate::physical_expr::sort::FFI_PhysicalSortExpr;
32use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
33
/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries.
/// For an explanation of each field, see the corresponding field
/// defined in [`AccumulatorArgs`].
///
/// Values are produced from an [`AccumulatorArgs`] via `TryFrom` and are
/// consumed on the other side by converting into [`ForeignAccumulatorArgs`].
// `#[repr(C)]` lays fields out in declaration order, so the order below is
// part of the struct's binary layout.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_AccumulatorArgs {
    // `AccumulatorArgs::return_field`, exported as an `FFI_ArrowSchema`.
    return_field: WrappedSchema,
    // `AccumulatorArgs::schema`, exported as an `FFI_ArrowSchema`.
    schema: WrappedSchema,
    // Copied verbatim from `AccumulatorArgs::ignore_nulls`.
    ignore_nulls: bool,
    // One FFI wrapper per entry of `AccumulatorArgs::order_bys`.
    order_bys: RVec<FFI_PhysicalSortExpr>,
    // Copied verbatim from `AccumulatorArgs::is_reversed`.
    is_reversed: bool,
    // Owned copy of `AccumulatorArgs::name`.
    name: RString,
    // Copied verbatim from `AccumulatorArgs::is_distinct`.
    is_distinct: bool,
    // One FFI wrapper per entry of `AccumulatorArgs::exprs`.
    exprs: RVec<FFI_PhysicalExpr>,
    // `AccumulatorArgs::expr_fields`, converted by `vec_fieldref_to_rvec_wrapped`.
    expr_fields: RVec<WrappedSchema>,
}
50
51impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
52    type Error = DataFusionError;
53    fn try_from(args: AccumulatorArgs) -> Result<Self, DataFusionError> {
54        let return_field =
55            WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
56        let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
57
58        let order_bys: RVec<_> = args
59            .order_bys
60            .iter()
61            .map(FFI_PhysicalSortExpr::from)
62            .collect();
63
64        let exprs = args
65            .exprs
66            .iter()
67            .map(Arc::clone)
68            .map(FFI_PhysicalExpr::from)
69            .collect();
70
71        let expr_fields = vec_fieldref_to_rvec_wrapped(args.expr_fields)?;
72
73        Ok(Self {
74            return_field,
75            schema,
76            ignore_nulls: args.ignore_nulls,
77            order_bys,
78            is_reversed: args.is_reversed,
79            name: args.name.into(),
80            is_distinct: args.is_distinct,
81            exprs,
82            expr_fields,
83        })
84    }
85}
86
87/// This struct mirrors AccumulatorArgs except that it contains owned data.
88/// It is necessary to create this struct so that we can parse the protobuf
89/// data across the FFI boundary and turn it into owned data that
90/// AccumulatorArgs can then reference.
91pub struct ForeignAccumulatorArgs {
92    pub return_field: FieldRef,
93    pub schema: Schema,
94    pub expr_fields: Vec<FieldRef>,
95    pub ignore_nulls: bool,
96    pub order_bys: Vec<PhysicalSortExpr>,
97    pub is_reversed: bool,
98    pub name: String,
99    pub is_distinct: bool,
100    pub exprs: Vec<Arc<dyn PhysicalExpr>>,
101}
102
103impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
104    type Error = DataFusionError;
105
106    fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
107        let return_field = Arc::new((&value.return_field.0).try_into()?);
108        let schema = Schema::try_from(&value.schema.0)?;
109
110        let order_bys = value.order_bys.iter().map(PhysicalSortExpr::from).collect();
111
112        let exprs = value
113            .exprs
114            .iter()
115            .map(<Arc<dyn PhysicalExpr>>::from)
116            .collect();
117
118        let expr_fields = rvec_wrapped_to_vec_fieldref(&value.expr_fields)?;
119
120        Ok(Self {
121            return_field,
122            schema,
123            expr_fields,
124            ignore_nulls: value.ignore_nulls,
125            order_bys,
126            is_reversed: value.is_reversed,
127            name: value.name.to_string(),
128            is_distinct: value.is_distinct,
129            exprs,
130        })
131    }
132}
133
134impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
135    fn from(value: &'a ForeignAccumulatorArgs) -> Self {
136        Self {
137            return_field: Arc::clone(&value.return_field),
138            schema: &value.schema,
139            expr_fields: &value.expr_fields,
140            ignore_nulls: value.ignore_nulls,
141            order_bys: &value.order_bys,
142            is_reversed: value.is_reversed,
143            name: value.name.as_str(),
144            is_distinct: value.is_distinct,
145            exprs: &value.exprs,
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_expr::function::AccumulatorArgs;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;

    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};

    /// Sends an `AccumulatorArgs` through `FFI_AccumulatorArgs` and
    /// `ForeignAccumulatorArgs` and checks that the borrowed view obtained at
    /// the end debug-prints exactly like the value we started with.
    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let sort_exprs = [PhysicalSortExpr::new_default(col("a", &input_schema)?)];
        let arg_exprs = [col("a", &input_schema)?];

        let original = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &input_schema,
            expr_fields: &[Field::new("a", DataType::Int32, true).into()],
            ignore_nulls: false,
            order_bys: &sort_exprs,
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &arg_exprs,
        };
        // Capture the debug form now: the conversion below consumes `original`.
        let expected_debug = format!("{original:?}");

        let ffi_args: FFI_AccumulatorArgs = original.try_into()?;
        let foreign_args = ForeignAccumulatorArgs::try_from(ffi_args)?;
        let round_tripped = AccumulatorArgs::from(&foreign_args);

        // AccumulatorArgs doesn't implement Eq, so the debug strings stand in
        // for a structural comparison.
        assert_eq!(expected_debug, format!("{round_tripped:?}"));

        Ok(())
    }
}