datafusion_ffi/udaf/
accumulator_args.rs1use std::sync::Arc;
19
20use abi_stable::StableAbi;
21use abi_stable::std_types::{RString, RVec};
22use arrow::datatypes::Schema;
23use arrow::ffi::FFI_ArrowSchema;
24use arrow_schema::FieldRef;
25use datafusion_common::error::DataFusionError;
26use datafusion_expr::function::AccumulatorArgs;
27use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
28
29use crate::arrow_wrappers::WrappedSchema;
30use crate::physical_expr::FFI_PhysicalExpr;
31use crate::physical_expr::sort::FFI_PhysicalSortExpr;
32use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
33
34#[repr(C)]
38#[derive(Debug, StableAbi)]
39pub struct FFI_AccumulatorArgs {
40 return_field: WrappedSchema,
41 schema: WrappedSchema,
42 ignore_nulls: bool,
43 order_bys: RVec<FFI_PhysicalSortExpr>,
44 is_reversed: bool,
45 name: RString,
46 is_distinct: bool,
47 exprs: RVec<FFI_PhysicalExpr>,
48 expr_fields: RVec<WrappedSchema>,
49}
50
51impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
52 type Error = DataFusionError;
53 fn try_from(args: AccumulatorArgs) -> Result<Self, DataFusionError> {
54 let return_field =
55 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
56 let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
57
58 let order_bys: RVec<_> = args
59 .order_bys
60 .iter()
61 .map(FFI_PhysicalSortExpr::from)
62 .collect();
63
64 let exprs = args
65 .exprs
66 .iter()
67 .map(Arc::clone)
68 .map(FFI_PhysicalExpr::from)
69 .collect();
70
71 let expr_fields = vec_fieldref_to_rvec_wrapped(args.expr_fields)?;
72
73 Ok(Self {
74 return_field,
75 schema,
76 ignore_nulls: args.ignore_nulls,
77 order_bys,
78 is_reversed: args.is_reversed,
79 name: args.name.into(),
80 is_distinct: args.is_distinct,
81 exprs,
82 expr_fields,
83 })
84 }
85}
86
87pub struct ForeignAccumulatorArgs {
92 pub return_field: FieldRef,
93 pub schema: Schema,
94 pub expr_fields: Vec<FieldRef>,
95 pub ignore_nulls: bool,
96 pub order_bys: Vec<PhysicalSortExpr>,
97 pub is_reversed: bool,
98 pub name: String,
99 pub is_distinct: bool,
100 pub exprs: Vec<Arc<dyn PhysicalExpr>>,
101}
102
103impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
104 type Error = DataFusionError;
105
106 fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
107 let return_field = Arc::new((&value.return_field.0).try_into()?);
108 let schema = Schema::try_from(&value.schema.0)?;
109
110 let order_bys = value.order_bys.iter().map(PhysicalSortExpr::from).collect();
111
112 let exprs = value
113 .exprs
114 .iter()
115 .map(<Arc<dyn PhysicalExpr>>::from)
116 .collect();
117
118 let expr_fields = rvec_wrapped_to_vec_fieldref(&value.expr_fields)?;
119
120 Ok(Self {
121 return_field,
122 schema,
123 expr_fields,
124 ignore_nulls: value.ignore_nulls,
125 order_bys,
126 is_reversed: value.is_reversed,
127 name: value.name.to_string(),
128 is_distinct: value.is_distinct,
129 exprs,
130 })
131 }
132}
133
134impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
135 fn from(value: &'a ForeignAccumulatorArgs) -> Self {
136 Self {
137 return_field: Arc::clone(&value.return_field),
138 schema: &value.schema,
139 expr_fields: &value.expr_fields,
140 ignore_nulls: value.ignore_nulls,
141 order_bys: &value.order_bys,
142 is_reversed: value.is_reversed,
143 name: value.name.as_str(),
144 is_distinct: value.is_distinct,
145 exprs: &value.exprs,
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_expr::function::AccumulatorArgs;
    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::expressions::col;

    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};

    /// Sends an `AccumulatorArgs` through `FFI_AccumulatorArgs` and
    /// `ForeignAccumulatorArgs` and checks that the result renders the same
    /// `Debug` output as the input.
    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        // Fixtures that the borrowed `AccumulatorArgs` points into.
        let sort_exprs = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let arg_exprs = [col("a", &schema)?];

        let orig_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            expr_fields: &[Field::new("a", DataType::Int32, true).into()],
            ignore_nulls: false,
            order_bys: &sort_exprs,
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &arg_exprs,
        };

        // Capture the rendering before `orig_args` is consumed below.
        let expected = format!("{orig_args:?}");

        let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?;
        let foreign_args = ForeignAccumulatorArgs::try_from(ffi_args)?;
        let round_trip_args = AccumulatorArgs::from(&foreign_args);

        let actual = format!("{round_trip_args:?}");

        // The comparison is done on the `Debug` strings of both values.
        assert_eq!(expected, actual);

        Ok(())
    }
}