datafusion_comet_spark_expr/struct_funcs/
create_named_struct.rs1use arrow::record_batch::RecordBatch;
19use arrow_array::StructArray;
20use arrow_schema::{DataType, Field, Schema};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::Result as DataFusionResult;
23use datafusion_physical_expr::PhysicalExpr;
24use std::{
25 any::Any,
26 fmt::{Display, Formatter},
27 hash::Hash,
28 sync::Arc,
29};
30
31#[derive(Debug, Hash, PartialEq, Eq)]
32pub struct CreateNamedStruct {
33 values: Vec<Arc<dyn PhysicalExpr>>,
34 names: Vec<String>,
35}
36
37impl CreateNamedStruct {
38 pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, names: Vec<String>) -> Self {
39 Self { values, names }
40 }
41
42 fn fields(&self, schema: &Schema) -> DataFusionResult<Vec<Field>> {
43 self.values
44 .iter()
45 .zip(&self.names)
46 .map(|(expr, name)| {
47 let data_type = expr.data_type(schema)?;
48 let nullable = expr.nullable(schema)?;
49 Ok(Field::new(name, data_type, nullable))
50 })
51 .collect()
52 }
53}
54
55impl PhysicalExpr for CreateNamedStruct {
56 fn as_any(&self) -> &dyn Any {
57 self
58 }
59
60 fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
61 let fields = self.fields(input_schema)?;
62 Ok(DataType::Struct(fields.into()))
63 }
64
65 fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
66 Ok(false)
67 }
68
69 fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
70 let values = self
71 .values
72 .iter()
73 .map(|expr| expr.evaluate(batch))
74 .collect::<datafusion_common::Result<Vec<_>>>()?;
75 let arrays = ColumnarValue::values_to_arrays(&values)?;
76 let fields = self.fields(&batch.schema())?;
77 Ok(ColumnarValue::Array(Arc::new(StructArray::new(
78 fields.into(),
79 arrays,
80 None,
81 ))))
82 }
83
84 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
85 self.values.iter().collect()
86 }
87
88 fn with_new_children(
89 self: Arc<Self>,
90 children: Vec<Arc<dyn PhysicalExpr>>,
91 ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
92 Ok(Arc::new(CreateNamedStruct::new(
93 children.clone(),
94 self.names.clone(),
95 )))
96 }
97}
98
99impl Display for CreateNamedStruct {
100 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101 write!(
102 f,
103 "CreateNamedStruct [values: {:?}, names: {:?}]",
104 self.values, self.names
105 )
106 }
107}
108
#[cfg(test)]
mod test {
    use super::CreateNamedStruct;
    use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Wraps `column` in a single-column batch named "a" whose declared type
    /// is an Int32-keyed dictionary of `value_type`, evaluates a one-field
    /// `CreateNamedStruct` over it, and returns the length of the result.
    fn evaluated_struct_len(column: Arc<dyn Array>, value_type: DataType) -> Result<usize> {
        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, false)]));
        let batch = RecordBatch::try_new(schema, vec![column])?;
        let expr =
            CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], vec!["a".to_string()]);
        match expr.evaluate(&batch)? {
            ColumnarValue::Array(array) => Ok(array.len()),
            _ => unreachable!(),
        }
    }

    /// A dictionary-encoded Int32 column can be wrapped into a struct.
    #[test]
    fn test_create_struct_from_dict_encoded_i32() -> Result<()> {
        let dict_keys = Int32Array::from(vec![0, 1, 2]);
        let dict_values = Int32Array::from(vec![0, 111, 233]);
        let dict = DictionaryArray::try_new(dict_keys, Arc::new(dict_values))?;
        let row_count = evaluated_struct_len(Arc::new(dict), DataType::Int32)?;
        assert_eq!(3, row_count);
        Ok(())
    }

    /// A dictionary-encoded string column can be wrapped into a struct.
    #[test]
    fn test_create_struct_from_dict_encoded_string() -> Result<()> {
        let dict_keys = Int32Array::from(vec![0, 1, 2]);
        let dict_values =
            StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
        let dict = DictionaryArray::try_new(dict_keys, Arc::new(dict_values))?;
        let row_count = evaluated_struct_len(Arc::new(dict), DataType::Utf8)?;
        assert_eq!(3, row_count);
        Ok(())
    }
}