1use std::sync::Arc;
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{
22 std_types::{RString, RVec},
23 StableAbi,
24};
25use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema};
26use arrow_schema::FieldRef;
27use datafusion::{
28 error::DataFusionError,
29 logical_expr::function::AccumulatorArgs,
30 physical_expr::{PhysicalExpr, PhysicalSortExpr},
31 prelude::SessionContext,
32};
33use datafusion_common::exec_datafusion_err;
34use datafusion_proto::{
35 physical_plan::{
36 from_proto::{parse_physical_exprs, parse_physical_sort_exprs},
37 to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs},
38 DefaultPhysicalExtensionCodec,
39 },
40 protobuf::PhysicalAggregateExprNode,
41};
42use prost::Message;
43
/// FFI-stable carrier for the contents of an [`AccumulatorArgs`].
///
/// The struct is `#[repr(C)]` and derives [`StableAbi`], and every field uses
/// a type intended to cross the FFI boundary. It is produced from an
/// [`AccumulatorArgs`] through `TryFrom` and converted into an owned
/// [`ForeignAccumulatorArgs`] by the `TryFrom<FFI_AccumulatorArgs>` impl.
#[repr(C)]
#[derive(Debug, StableAbi)]
// The `FFI_` prefix is not camel case, so the naming lint is silenced here.
#[allow(non_camel_case_types)]
pub struct FFI_AccumulatorArgs {
    /// Return field of the aggregate, exported as a wrapped `FFI_ArrowSchema`.
    return_field: WrappedSchema,
    /// Input schema, exported as a wrapped `FFI_ArrowSchema`.
    schema: WrappedSchema,
    /// Copied directly from `AccumulatorArgs::is_reversed`.
    is_reversed: bool,
    /// Copied from `AccumulatorArgs::name`.
    name: RString,
    /// Protobuf-encoded `PhysicalAggregateExprNode` holding the argument
    /// expressions, the ordering requirement, and the `distinct` /
    /// `ignore_nulls` flags.
    physical_expr_def: RVec<u8>,
}
57
58impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
59 type Error = DataFusionError;
60
61 fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> {
62 let return_field =
63 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
64 let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
65
66 let codec = DefaultPhysicalExtensionCodec {};
67 let ordering_req =
68 serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?;
69
70 let expr = serialize_physical_exprs(args.exprs, &codec)?;
71
72 let physical_expr_def = PhysicalAggregateExprNode {
73 expr,
74 ordering_req,
75 distinct: args.is_distinct,
76 ignore_nulls: args.ignore_nulls,
77 fun_definition: None,
78 aggregate_function: None,
79 human_display: args.name.to_string(),
80 };
81 let physical_expr_def = physical_expr_def.encode_to_vec().into();
82
83 Ok(Self {
84 return_field,
85 schema,
86 is_reversed: args.is_reversed,
87 name: args.name.into(),
88 physical_expr_def,
89 })
90 }
91}
92
/// Owned counterpart of [`AccumulatorArgs`], rebuilt from an
/// [`FFI_AccumulatorArgs`].
///
/// This struct owns the decoded schema, fields, and expressions. A borrowed
/// `AccumulatorArgs<'_>` view over it is available through
/// `From<&ForeignAccumulatorArgs>`.
pub struct ForeignAccumulatorArgs {
    /// Return field of the aggregate, imported from the FFI schema.
    pub return_field: FieldRef,
    /// Input schema, imported from the FFI schema.
    pub schema: Schema,
    /// Result of calling `return_field` on each entry of `exprs` against
    /// `schema`.
    pub expr_fields: Vec<FieldRef>,
    /// `ignore_nulls` flag read from the decoded protobuf node.
    pub ignore_nulls: bool,
    /// Sort expressions parsed from the protobuf node's `ordering_req`.
    pub order_bys: Vec<PhysicalSortExpr>,
    /// Copied from `FFI_AccumulatorArgs::is_reversed`.
    pub is_reversed: bool,
    /// Copied from `FFI_AccumulatorArgs::name`.
    pub name: String,
    /// `distinct` flag read from the decoded protobuf node.
    pub is_distinct: bool,
    /// Argument expressions parsed from the protobuf node's `expr`.
    pub exprs: Vec<Arc<dyn PhysicalExpr>>,
}
108
109impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
110 type Error = DataFusionError;
111
112 fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
113 let proto_def = PhysicalAggregateExprNode::decode(
114 value.physical_expr_def.as_ref(),
115 )
116 .map_err(|e| {
117 exec_datafusion_err!("Failed to decode PhysicalAggregateExprNode: {e}")
118 })?;
119
120 let return_field = Arc::new((&value.return_field.0).try_into()?);
121 let schema = Schema::try_from(&value.schema.0)?;
122
123 let default_ctx = SessionContext::new();
124 let task_ctx = default_ctx.task_ctx();
125 let codex = DefaultPhysicalExtensionCodec {};
126
127 let order_bys = parse_physical_sort_exprs(
128 &proto_def.ordering_req,
129 &task_ctx,
130 &schema,
131 &codex,
132 )?;
133
134 let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;
135
136 let expr_fields = exprs
137 .iter()
138 .map(|e| e.return_field(&schema))
139 .collect::<Result<Vec<_>, _>>()?;
140
141 Ok(Self {
142 return_field,
143 schema,
144 expr_fields,
145 ignore_nulls: proto_def.ignore_nulls,
146 order_bys,
147 is_reversed: value.is_reversed,
148 name: value.name.to_string(),
149 is_distinct: proto_def.distinct,
150 exprs,
151 })
152 }
153}
154
155impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
156 fn from(value: &'a ForeignAccumulatorArgs) -> Self {
157 Self {
158 return_field: Arc::clone(&value.return_field),
159 schema: &value.schema,
160 expr_fields: &value.expr_fields,
161 ignore_nulls: value.ignore_nulls,
162 order_bys: &value.order_bys,
163 is_reversed: value.is_reversed,
164 name: value.name.as_str(),
165 is_distinct: value.is_distinct,
166 exprs: &value.exprs,
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        error::Result, logical_expr::function::AccumulatorArgs,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
    };

    /// Converts an `AccumulatorArgs` to `FFI_AccumulatorArgs`, then to
    /// `ForeignAccumulatorArgs`, then back to a borrowed `AccumulatorArgs`,
    /// and checks that the `Debug` rendering is unchanged.
    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let original = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &input_schema,
            expr_fields: &[Field::new("a", DataType::Int32, true).into()],
            ignore_nulls: false,
            order_bys: &[PhysicalSortExpr::new_default(col("a", &input_schema)?)],
            is_reversed: false,
            name: "round_trip",
            is_distinct: true,
            exprs: &[col("a", &input_schema)?],
        };
        // Capture the rendering before the conversion consumes `original`.
        let expected = format!("{original:?}");

        let ffi = FFI_AccumulatorArgs::try_from(original)?;
        let foreign = ForeignAccumulatorArgs::try_from(ffi)?;
        let restored = AccumulatorArgs::from(&foreign);

        assert_eq!(expected, format!("{restored:?}"));

        Ok(())
    }
}