datafusion_ffi/udwf/
partition_evaluator_args.rs1use std::{collections::HashMap, sync::Arc};
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{std_types::RVec, StableAbi};
22use arrow::{
23 datatypes::{DataType, Field, Schema, SchemaRef},
24 error::ArrowError,
25 ffi::FFI_ArrowSchema,
26};
27use arrow_schema::FieldRef;
28use datafusion::{
29 error::{DataFusionError, Result},
30 logical_expr::function::PartitionEvaluatorArgs,
31 physical_plan::{expressions::Column, PhysicalExpr},
32 prelude::SessionContext,
33};
34use datafusion_common::exec_datafusion_err;
35use datafusion_proto::{
36 physical_plan::{
37 from_proto::parse_physical_expr, to_proto::serialize_physical_exprs,
38 DefaultPhysicalExtensionCodec,
39 },
40 protobuf::PhysicalExprNode,
41};
42use prost::Message;
43
44#[repr(C)]
48#[derive(Debug, StableAbi)]
49#[allow(non_camel_case_types)]
50pub struct FFI_PartitionEvaluatorArgs {
51 input_exprs: RVec<RVec<u8>>,
52 input_fields: RVec<WrappedSchema>,
53 is_reversed: bool,
54 ignore_nulls: bool,
55 schema: WrappedSchema,
56}
57
58impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
59 type Error = DataFusionError;
60 fn try_from(args: PartitionEvaluatorArgs) -> Result<Self, DataFusionError> {
61 let required_columns: HashMap<usize, (&str, &DataType)> = args
69 .input_exprs()
70 .iter()
71 .zip(args.input_fields())
72 .filter_map(|(expr, field)| {
73 expr.as_any()
74 .downcast_ref::<Column>()
75 .map(|column| (column.index(), (column.name(), field.data_type())))
76 })
77 .collect();
78
79 let max_column = required_columns.keys().max();
80 let fields: Vec<_> = max_column
81 .map(|max_column| {
82 (0..(max_column + 1))
83 .map(|idx| match required_columns.get(&idx) {
84 Some((name, data_type)) => {
85 Field::new(*name, (*data_type).clone(), true)
86 }
87 None => Field::new(
88 format!("ffi_partition_evaluator_col_{idx}"),
89 DataType::Null,
90 true,
91 ),
92 })
93 .collect()
94 })
95 .unwrap_or_default();
96
97 let schema = Arc::new(Schema::new(fields));
98
99 let codec = DefaultPhysicalExtensionCodec {};
100 let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)?
101 .into_iter()
102 .map(|expr_node| expr_node.encode_to_vec().into())
103 .collect();
104
105 let input_fields = args
106 .input_fields()
107 .iter()
108 .map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema))
109 .collect::<Result<Vec<_>, ArrowError>>()?
110 .into();
111
112 let schema: WrappedSchema = schema.into();
113
114 Ok(Self {
115 input_exprs,
116 input_fields,
117 schema,
118 is_reversed: args.is_reversed(),
119 ignore_nulls: args.ignore_nulls(),
120 })
121 }
122}
123
124pub struct ForeignPartitionEvaluatorArgs {
129 input_exprs: Vec<Arc<dyn PhysicalExpr>>,
130 input_fields: Vec<FieldRef>,
131 is_reversed: bool,
132 ignore_nulls: bool,
133}
134
135impl TryFrom<FFI_PartitionEvaluatorArgs> for ForeignPartitionEvaluatorArgs {
136 type Error = DataFusionError;
137
138 fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result<Self> {
139 let default_ctx = SessionContext::new();
140 let codec = DefaultPhysicalExtensionCodec {};
141
142 let schema: SchemaRef = value.schema.into();
143
144 let input_exprs = value
145 .input_exprs
146 .into_iter()
147 .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref()))
148 .collect::<std::result::Result<Vec<_>, prost::DecodeError>>()
149 .map_err(|e| exec_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))?
150 .iter()
151 .map(|expr_node| {
152 parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec)
153 })
154 .collect::<Result<Vec<_>>>()?;
155
156 let input_fields = input_exprs
157 .iter()
158 .map(|expr| expr.return_field(&schema))
159 .collect::<Result<Vec<_>>>()?;
160
161 Ok(Self {
162 input_exprs,
163 input_fields,
164 is_reversed: value.is_reversed,
165 ignore_nulls: value.ignore_nulls,
166 })
167 }
168}
169
170impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a> {
171 fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self {
172 PartitionEvaluatorArgs::new(
173 &value.input_exprs,
174 &value.input_fields,
175 value.is_reversed,
176 value.ignore_nulls,
177 )
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Export `args` over FFI and import them back, returning the owned result.
    fn round_trip(args: PartitionEvaluatorArgs) -> Result<ForeignPartitionEvaluatorArgs> {
        let ffi_args = FFI_PartitionEvaluatorArgs::try_from(args)?;
        ForeignPartitionEvaluatorArgs::try_from(ffi_args)
    }

    #[test]
    fn round_trip_column_argument() -> Result<()> {
        // Column index 2 leaves indices 0 and 1 to be filled with placeholders.
        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("b", 2))];
        let fields: Vec<FieldRef> =
            vec![Arc::new(Field::new("b", DataType::Int32, true))];

        let foreign =
            round_trip(PartitionEvaluatorArgs::new(&exprs, &fields, true, false))?;
        let args: PartitionEvaluatorArgs = (&foreign).into();

        assert_eq!(args.input_exprs().len(), 1);
        let column = args.input_exprs()[0]
            .as_any()
            .downcast_ref::<Column>()
            .expect("argument should still be a column");
        assert_eq!((column.name(), column.index()), ("b", 2));

        assert_eq!(args.input_fields().len(), 1);
        assert_eq!(args.input_fields()[0].data_type(), &DataType::Int32);
        assert!(args.input_fields()[0].is_nullable());

        assert!(args.is_reversed());
        assert!(!args.ignore_nulls());
        Ok(())
    }

    #[test]
    fn round_trip_without_arguments() -> Result<()> {
        let foreign = round_trip(PartitionEvaluatorArgs::new(&[], &[], false, true))?;
        let args: PartitionEvaluatorArgs = (&foreign).into();

        assert!(args.input_exprs().is_empty());
        assert!(args.input_fields().is_empty());
        assert!(!args.is_reversed());
        assert!(args.ignore_nulls());
        Ok(())
    }
}