1use std::sync::Arc;
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{
22 std_types::{RString, RVec},
23 StableAbi,
24};
25use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema};
26use arrow_schema::FieldRef;
27use datafusion::{
28 error::DataFusionError,
29 logical_expr::function::AccumulatorArgs,
30 physical_expr::{PhysicalExpr, PhysicalSortExpr},
31 prelude::SessionContext,
32};
33use datafusion_proto::{
34 physical_plan::{
35 from_proto::{parse_physical_exprs, parse_physical_sort_exprs},
36 to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs},
37 DefaultPhysicalExtensionCodec,
38 },
39 protobuf::PhysicalAggregateExprNode,
40};
41use prost::Message;
42
/// A stable, C-ABI-compatible representation of [`AccumulatorArgs`] that can
/// be passed across an FFI boundary.
///
/// The return field and input schema are carried as Arrow C schemas. The
/// input expressions, ordering requirements, and the `distinct` /
/// `ignore_nulls` flags are carried inside a protobuf-encoded
/// `PhysicalAggregateExprNode` stored in `physical_expr_def`.
#[repr(C)]
#[derive(Debug, StableAbi)]
// The `FFI_` prefix intentionally breaks UpperCamelCase to mark FFI types.
#[allow(non_camel_case_types)]
pub struct FFI_AccumulatorArgs {
    /// Output field of the aggregate, exported as an Arrow C schema.
    return_field: WrappedSchema,
    /// Schema of the aggregate's input, exported as an Arrow C schema.
    schema: WrappedSchema,
    /// Copied directly from `AccumulatorArgs::is_reversed`.
    is_reversed: bool,
    /// Name of the aggregate expression (`AccumulatorArgs::name`).
    name: RString,
    /// Protobuf-encoded `PhysicalAggregateExprNode` holding the input
    /// expressions, ordering requirements, and `distinct` / `ignore_nulls`.
    physical_expr_def: RVec<u8>,
}
56
57impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
58 type Error = DataFusionError;
59
60 fn try_from(args: AccumulatorArgs) -> Result<Self, Self::Error> {
61 let return_field =
62 WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?);
63 let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?);
64
65 let codec = DefaultPhysicalExtensionCodec {};
66 let ordering_req =
67 serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?;
68
69 let expr = serialize_physical_exprs(args.exprs, &codec)?;
70
71 let physical_expr_def = PhysicalAggregateExprNode {
72 expr,
73 ordering_req,
74 distinct: args.is_distinct,
75 ignore_nulls: args.ignore_nulls,
76 fun_definition: None,
77 aggregate_function: None,
78 human_display: args.name.to_string(),
79 };
80 let physical_expr_def = physical_expr_def.encode_to_vec().into();
81
82 Ok(Self {
83 return_field,
84 schema,
85 is_reversed: args.is_reversed,
86 name: args.name.into(),
87 physical_expr_def,
88 })
89 }
90}
91
92pub struct ForeignAccumulatorArgs {
97 pub return_field: FieldRef,
98 pub schema: Schema,
99 pub ignore_nulls: bool,
100 pub order_bys: Vec<PhysicalSortExpr>,
101 pub is_reversed: bool,
102 pub name: String,
103 pub is_distinct: bool,
104 pub exprs: Vec<Arc<dyn PhysicalExpr>>,
105}
106
107impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
108 type Error = DataFusionError;
109
110 fn try_from(value: FFI_AccumulatorArgs) -> Result<Self, Self::Error> {
111 let proto_def =
112 PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref())
113 .map_err(|e| DataFusionError::Execution(e.to_string()))?;
114
115 let return_field = Arc::new((&value.return_field.0).try_into()?);
116 let schema = Schema::try_from(&value.schema.0)?;
117
118 let default_ctx = SessionContext::new();
119 let codex = DefaultPhysicalExtensionCodec {};
120
121 let order_bys = parse_physical_sort_exprs(
122 &proto_def.ordering_req,
123 &default_ctx,
124 &schema,
125 &codex,
126 )?;
127
128 let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?;
129
130 Ok(Self {
131 return_field,
132 schema,
133 ignore_nulls: proto_def.ignore_nulls,
134 order_bys,
135 is_reversed: value.is_reversed,
136 name: value.name.to_string(),
137 is_distinct: proto_def.distinct,
138 exprs,
139 })
140 }
141}
142
143impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
144 fn from(value: &'a ForeignAccumulatorArgs) -> Self {
145 Self {
146 return_field: Arc::clone(&value.return_field),
147 schema: &value.schema,
148 ignore_nulls: value.ignore_nulls,
149 order_bys: &value.order_bys,
150 is_reversed: value.is_reversed,
151 name: value.name.as_str(),
152 is_distinct: value.is_distinct,
153 exprs: &value.exprs,
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        error::Result, logical_expr::function::AccumulatorArgs,
        physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
    };

    /// Converts `AccumulatorArgs` -> `FFI_AccumulatorArgs` ->
    /// `ForeignAccumulatorArgs` -> `AccumulatorArgs` for the given flag values.
    /// It asserts that the result's `Debug` output matches the input's and
    /// that each boolean flag survived the round trip.
    fn assert_round_trip(ignore_nulls: bool, is_reversed: bool, is_distinct: bool) -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let order_bys = [PhysicalSortExpr::new_default(col("a", &schema)?)];
        let exprs = [col("a", &schema)?];
        let orig_args = AccumulatorArgs {
            return_field: Field::new("f", DataType::Float64, true).into(),
            schema: &schema,
            ignore_nulls,
            order_bys: &order_bys,
            is_reversed,
            name: "round_trip",
            is_distinct,
            exprs: &exprs,
        };
        let orig_str = format!("{orig_args:?}");

        let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?;
        let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?;
        let round_trip_args: AccumulatorArgs = (&foreign_args).into();

        let round_trip_str = format!("{round_trip_args:?}");
        let case = format!(
            "ignore_nulls={ignore_nulls}, is_reversed={is_reversed}, is_distinct={is_distinct}"
        );

        assert_eq!(orig_str, round_trip_str, "{case}");
        // Explicit per-flag checks so a failure names the offending field.
        assert_eq!(round_trip_args.ignore_nulls, ignore_nulls, "{case}");
        assert_eq!(round_trip_args.is_reversed, is_reversed, "{case}");
        assert_eq!(round_trip_args.is_distinct, is_distinct, "{case}");

        Ok(())
    }

    #[test]
    fn test_round_trip_accumulator_args() -> Result<()> {
        // Cover every flag combination. A single fixed combination cannot
        // detect a flag that is dropped, defaulted, or swapped with another
        // during conversion.
        for ignore_nulls in [false, true] {
            for is_reversed in [false, true] {
                for is_distinct in [false, true] {
                    assert_round_trip(ignore_nulls, is_reversed, is_distinct)?;
                }
            }
        }

        Ok(())
    }
}