datafusion_ffi/udwf/
partition_evaluator_args.rs1use std::{collections::HashMap, sync::Arc};
19
20use crate::arrow_wrappers::WrappedSchema;
21use abi_stable::{std_types::RVec, StableAbi};
22use arrow::{
23 datatypes::{DataType, Field, Schema, SchemaRef},
24 error::ArrowError,
25 ffi::FFI_ArrowSchema,
26};
27use arrow_schema::FieldRef;
28use datafusion::{
29 error::{DataFusionError, Result},
30 logical_expr::function::PartitionEvaluatorArgs,
31 physical_plan::{expressions::Column, PhysicalExpr},
32 prelude::SessionContext,
33};
34use datafusion_proto::{
35 physical_plan::{
36 from_proto::parse_physical_expr, to_proto::serialize_physical_exprs,
37 DefaultPhysicalExtensionCodec,
38 },
39 protobuf::PhysicalExprNode,
40};
41use prost::Message;
42
43#[repr(C)]
47#[derive(Debug, StableAbi)]
48#[allow(non_camel_case_types)]
49pub struct FFI_PartitionEvaluatorArgs {
50 input_exprs: RVec<RVec<u8>>,
51 input_fields: RVec<WrappedSchema>,
52 is_reversed: bool,
53 ignore_nulls: bool,
54 schema: WrappedSchema,
55}
56
57impl TryFrom<PartitionEvaluatorArgs<'_>> for FFI_PartitionEvaluatorArgs {
58 type Error = DataFusionError;
59 fn try_from(args: PartitionEvaluatorArgs) -> Result<Self, DataFusionError> {
60 let required_columns: HashMap<usize, (&str, &DataType)> = args
68 .input_exprs()
69 .iter()
70 .zip(args.input_fields())
71 .filter_map(|(expr, field)| {
72 expr.as_any()
73 .downcast_ref::<Column>()
74 .map(|column| (column.index(), (column.name(), field.data_type())))
75 })
76 .collect();
77
78 let max_column = required_columns.keys().max();
79 let fields: Vec<_> = max_column
80 .map(|max_column| {
81 (0..(max_column + 1))
82 .map(|idx| match required_columns.get(&idx) {
83 Some((name, data_type)) => {
84 Field::new(*name, (*data_type).clone(), true)
85 }
86 None => Field::new(
87 format!("ffi_partition_evaluator_col_{idx}"),
88 DataType::Null,
89 true,
90 ),
91 })
92 .collect()
93 })
94 .unwrap_or_default();
95
96 let schema = Arc::new(Schema::new(fields));
97
98 let codec = DefaultPhysicalExtensionCodec {};
99 let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)?
100 .into_iter()
101 .map(|expr_node| expr_node.encode_to_vec().into())
102 .collect();
103
104 let input_fields = args
105 .input_fields()
106 .iter()
107 .map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema))
108 .collect::<Result<Vec<_>, ArrowError>>()?
109 .into();
110
111 let schema: WrappedSchema = schema.into();
112
113 Ok(Self {
114 input_exprs,
115 input_fields,
116 schema,
117 is_reversed: args.is_reversed(),
118 ignore_nulls: args.ignore_nulls(),
119 })
120 }
121}
122
123pub struct ForeignPartitionEvaluatorArgs {
128 input_exprs: Vec<Arc<dyn PhysicalExpr>>,
129 input_fields: Vec<FieldRef>,
130 is_reversed: bool,
131 ignore_nulls: bool,
132}
133
134impl TryFrom<FFI_PartitionEvaluatorArgs> for ForeignPartitionEvaluatorArgs {
135 type Error = DataFusionError;
136
137 fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result<Self> {
138 let default_ctx = SessionContext::new();
139 let codec = DefaultPhysicalExtensionCodec {};
140
141 let schema: SchemaRef = value.schema.into();
142
143 let input_exprs = value
144 .input_exprs
145 .into_iter()
146 .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref()))
147 .collect::<std::result::Result<Vec<_>, prost::DecodeError>>()
148 .map_err(|e| DataFusionError::Execution(e.to_string()))?
149 .iter()
150 .map(|expr_node| {
151 parse_physical_expr(expr_node, &default_ctx, &schema, &codec)
152 })
153 .collect::<Result<Vec<_>>>()?;
154
155 let input_fields = input_exprs
156 .iter()
157 .map(|expr| expr.return_field(&schema))
158 .collect::<Result<Vec<_>>>()?;
159
160 Ok(Self {
161 input_exprs,
162 input_fields,
163 is_reversed: value.is_reversed,
164 ignore_nulls: value.ignore_nulls,
165 })
166 }
167}
168
169impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a> {
170 fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self {
171 PartitionEvaluatorArgs::new(
172 &value.input_exprs,
173 &value.input_fields,
174 value.is_reversed,
175 value.ignore_nulls,
176 )
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Column-only window arguments should survive a round trip through the
    /// FFI representation unchanged. The round trip also covers a gap in the
    /// column indices, which is filled by a placeholder slot.
    #[test]
    fn round_trip_column_arguments() -> Result<()> {
        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 2)),
        ];
        let fields: Vec<FieldRef> = vec![
            Arc::new(Field::new("a", DataType::Int32, true)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];
        let local = PartitionEvaluatorArgs::new(&exprs, &fields, true, false);

        let ffi_args = FFI_PartitionEvaluatorArgs::try_from(local)?;
        let foreign = ForeignPartitionEvaluatorArgs::try_from(ffi_args)?;
        let round_tripped = PartitionEvaluatorArgs::from(&foreign);

        assert_eq!(round_tripped.input_fields(), fields.as_slice());
        assert!(round_tripped.is_reversed());
        assert!(!round_tripped.ignore_nulls());

        let indices: Vec<Option<usize>> = round_tripped
            .input_exprs()
            .iter()
            .map(|expr| expr.as_any().downcast_ref::<Column>().map(Column::index))
            .collect();
        assert_eq!(indices, vec![Some(0), Some(2)]);

        Ok(())
    }
}