datafusion_comet_spark_expr/array_funcs/
get_array_struct_fields.rs1use arrow::record_batch::RecordBatch;
19use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray};
20use arrow_schema::{DataType, FieldRef, Schema};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::{
23 cast::{as_large_list_array, as_list_array},
24 internal_err, DataFusionError, Result as DataFusionResult,
25};
26use datafusion_physical_expr::PhysicalExpr;
27use std::hash::Hash;
28use std::{
29 any::Any,
30 fmt::{Debug, Display, Formatter},
31 sync::Arc,
32};
33
34#[derive(Debug, Eq)]
35pub struct GetArrayStructFields {
36 child: Arc<dyn PhysicalExpr>,
37 ordinal: usize,
38}
39
40impl Hash for GetArrayStructFields {
41 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42 self.child.hash(state);
43 self.ordinal.hash(state);
44 }
45}
46impl PartialEq for GetArrayStructFields {
47 fn eq(&self, other: &Self) -> bool {
48 self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
49 }
50}
51
52impl GetArrayStructFields {
53 pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
54 Self { child, ordinal }
55 }
56
57 fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
58 match self.child.data_type(input_schema)? {
59 DataType::List(field) | DataType::LargeList(field) => Ok(field),
60 data_type => Err(DataFusionError::Internal(format!(
61 "Unexpected data type in GetArrayStructFields: {:?}",
62 data_type
63 ))),
64 }
65 }
66
67 fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
68 match self.list_field(input_schema)?.data_type() {
69 DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
70 data_type => Err(DataFusionError::Internal(format!(
71 "Unexpected data type in GetArrayStructFields: {:?}",
72 data_type
73 ))),
74 }
75 }
76}
77
78impl PhysicalExpr for GetArrayStructFields {
79 fn as_any(&self) -> &dyn Any {
80 self
81 }
82
83 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
84 let struct_field = self.child_field(input_schema)?;
85 match self.child.data_type(input_schema)? {
86 DataType::List(_) => Ok(DataType::List(struct_field)),
87 DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
88 data_type => Err(DataFusionError::Internal(format!(
89 "Unexpected data type in GetArrayStructFields: {:?}",
90 data_type
91 ))),
92 }
93 }
94
95 fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
96 Ok(self.list_field(input_schema)?.is_nullable()
97 || self.child_field(input_schema)?.is_nullable())
98 }
99
100 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
101 let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
102
103 match child_value.data_type() {
104 DataType::List(_) => {
105 let list_array = as_list_array(&child_value)?;
106
107 get_array_struct_fields(list_array, self.ordinal)
108 }
109 DataType::LargeList(_) => {
110 let list_array = as_large_list_array(&child_value)?;
111
112 get_array_struct_fields(list_array, self.ordinal)
113 }
114 data_type => Err(DataFusionError::Internal(format!(
115 "Unexpected child type for ListExtract: {:?}",
116 data_type
117 ))),
118 }
119 }
120
121 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
122 vec![&self.child]
123 }
124
125 fn with_new_children(
126 self: Arc<Self>,
127 children: Vec<Arc<dyn PhysicalExpr>>,
128 ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
129 match children.len() {
130 1 => Ok(Arc::new(GetArrayStructFields::new(
131 Arc::clone(&children[0]),
132 self.ordinal,
133 ))),
134 _ => internal_err!("GetArrayStructFields should have exactly one child"),
135 }
136 }
137}
138
139fn get_array_struct_fields<O: OffsetSizeTrait>(
140 list_array: &GenericListArray<O>,
141 ordinal: usize,
142) -> DataFusionResult<ColumnarValue> {
143 let values = list_array
144 .values()
145 .as_any()
146 .downcast_ref::<StructArray>()
147 .expect("A struct is expected");
148
149 let column = Arc::clone(values.column(ordinal));
150 let field = Arc::clone(&values.fields()[ordinal]);
151
152 let offsets = list_array.offsets();
153 let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
154
155 Ok(ColumnarValue::Array(Arc::new(array)))
156}
157
158impl Display for GetArrayStructFields {
159 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160 write!(
161 f,
162 "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
163 self.child, self.ordinal
164 )
165 }
166}