datafusion_comet_spark_expr/struct_funcs/
create_named_struct.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::record_batch::RecordBatch;
19use arrow_array::StructArray;
20use arrow_schema::{DataType, Field, Schema};
21use datafusion::logical_expr::ColumnarValue;
22use datafusion_common::Result as DataFusionResult;
23use datafusion_physical_expr::PhysicalExpr;
24use std::{
25    any::Any,
26    fmt::{Display, Formatter},
27    hash::Hash,
28    sync::Arc,
29};
30
31#[derive(Debug, Hash, PartialEq, Eq)]
32pub struct CreateNamedStruct {
33    values: Vec<Arc<dyn PhysicalExpr>>,
34    names: Vec<String>,
35}
36
37impl CreateNamedStruct {
38    pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, names: Vec<String>) -> Self {
39        Self { values, names }
40    }
41
42    fn fields(&self, schema: &Schema) -> DataFusionResult<Vec<Field>> {
43        self.values
44            .iter()
45            .zip(&self.names)
46            .map(|(expr, name)| {
47                let data_type = expr.data_type(schema)?;
48                let nullable = expr.nullable(schema)?;
49                Ok(Field::new(name, data_type, nullable))
50            })
51            .collect()
52    }
53}
54
55impl PhysicalExpr for CreateNamedStruct {
56    fn as_any(&self) -> &dyn Any {
57        self
58    }
59
60    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
61        let fields = self.fields(input_schema)?;
62        Ok(DataType::Struct(fields.into()))
63    }
64
65    fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
66        Ok(false)
67    }
68
69    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
70        let values = self
71            .values
72            .iter()
73            .map(|expr| expr.evaluate(batch))
74            .collect::<datafusion_common::Result<Vec<_>>>()?;
75        let arrays = ColumnarValue::values_to_arrays(&values)?;
76        let fields = self.fields(&batch.schema())?;
77        Ok(ColumnarValue::Array(Arc::new(StructArray::new(
78            fields.into(),
79            arrays,
80            None,
81        ))))
82    }
83
84    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
85        self.values.iter().collect()
86    }
87
88    fn with_new_children(
89        self: Arc<Self>,
90        children: Vec<Arc<dyn PhysicalExpr>>,
91    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
92        Ok(Arc::new(CreateNamedStruct::new(
93            children.clone(),
94            self.names.clone(),
95        )))
96    }
97}
98
99impl Display for CreateNamedStruct {
100    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101        write!(
102            f,
103            "CreateNamedStruct [values: {:?}, names: {:?}]",
104            self.values, self.names
105        )
106    }
107}
108
#[cfg(test)]
mod test {
    use super::CreateNamedStruct;
    use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Wraps `column` in a single-column batch whose field "a" is declared as
    /// an Int32-keyed dictionary of `value_type`, evaluates a one-field
    /// `CreateNamedStruct` over it, and returns the length of the result.
    fn struct_len_over_dict_column(column: Arc<dyn Array>, value_type: DataType) -> Result<usize> {
        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type));
        let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, false)]));
        let batch = RecordBatch::try_new(schema, vec![column])?;
        let expr = CreateNamedStruct::new(
            vec![Arc::new(Column::new("a", 0))],
            vec!["a".to_string()],
        );
        match expr.evaluate(&batch)? {
            ColumnarValue::Array(array) => Ok(array.len()),
            // The expression is expected to produce an array here.
            _ => unreachable!(),
        }
    }

    /// A struct built from a dictionary-encoded Int32 column keeps all rows.
    #[test]
    fn test_create_struct_from_dict_encoded_i32() -> Result<()> {
        let dict = DictionaryArray::try_new(
            Int32Array::from(vec![0, 1, 2]),
            Arc::new(Int32Array::from(vec![0, 111, 233])),
        )?;
        let rows = struct_len_over_dict_column(Arc::new(dict), DataType::Int32)?;
        assert_eq!(3, rows);
        Ok(())
    }

    /// A struct built from a dictionary-encoded string column keeps all rows.
    #[test]
    fn test_create_struct_from_dict_encoded_string() -> Result<()> {
        let dict = DictionaryArray::try_new(
            Int32Array::from(vec![0, 1, 2]),
            Arc::new(StringArray::from(vec![
                "a".to_string(),
                "b".to_string(),
                "c".to_string(),
            ])),
        )?;
        let rows = struct_len_over_dict_column(Arc::new(dict), DataType::Utf8)?;
        assert_eq!(3, rows);
        Ok(())
    }
}