1use crate::physical_expr::PhysicalExpr;
21use arrow::{
22 compute::CastOptions,
23 datatypes::{DataType, FieldRef, Schema},
24 record_batch::RecordBatch,
25};
26use datafusion_common::{
27 format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column, Result, ScalarValue,
28};
29use datafusion_expr_common::columnar_value::ColumnarValue;
30use std::{
31 any::Any,
32 fmt::{self, Display},
33 hash::Hash,
34 sync::Arc,
35};
36#[derive(Debug, Clone, Eq)]
50pub struct CastColumnExpr {
51 expr: Arc<dyn PhysicalExpr>,
53 input_field: FieldRef,
55 target_field: FieldRef,
57 cast_options: CastOptions<'static>,
59}
60
61impl PartialEq for CastColumnExpr {
64 fn eq(&self, other: &Self) -> bool {
65 self.expr.eq(&other.expr)
66 && self.input_field.eq(&other.input_field)
67 && self.target_field.eq(&other.target_field)
68 && self.cast_options.eq(&other.cast_options)
69 }
70}
71
72impl Hash for CastColumnExpr {
73 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
74 self.expr.hash(state);
75 self.input_field.hash(state);
76 self.target_field.hash(state);
77 self.cast_options.hash(state);
78 }
79}
80
81impl CastColumnExpr {
82 pub fn new(
84 expr: Arc<dyn PhysicalExpr>,
85 input_field: FieldRef,
86 target_field: FieldRef,
87 cast_options: Option<CastOptions<'static>>,
88 ) -> Self {
89 Self {
90 expr,
91 input_field,
92 target_field,
93 cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
94 }
95 }
96
97 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
99 &self.expr
100 }
101
102 pub fn input_field(&self) -> &FieldRef {
104 &self.input_field
105 }
106
107 pub fn target_field(&self) -> &FieldRef {
109 &self.target_field
110 }
111}
112
113impl Display for CastColumnExpr {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 write!(
116 f,
117 "CAST_COLUMN({} AS {:?})",
118 self.expr,
119 self.target_field.data_type()
120 )
121 }
122}
123
124impl PhysicalExpr for CastColumnExpr {
125 fn as_any(&self) -> &dyn Any {
126 self
127 }
128
129 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
130 Ok(self.target_field.data_type().clone())
131 }
132
133 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
134 Ok(self.target_field.is_nullable())
135 }
136
137 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
138 let value = self.expr.evaluate(batch)?;
139 match value {
140 ColumnarValue::Array(array) => {
141 let casted =
142 cast_column(&array, self.target_field.as_ref(), &self.cast_options)?;
143 Ok(ColumnarValue::Array(casted))
144 }
145 ColumnarValue::Scalar(scalar) => {
146 let as_array = scalar.to_array_of_size(1)?;
147 let casted = cast_column(
148 &as_array,
149 self.target_field.as_ref(),
150 &self.cast_options,
151 )?;
152 let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;
153 Ok(ColumnarValue::Scalar(result))
154 }
155 }
156 }
157
158 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
159 Ok(Arc::clone(&self.target_field))
160 }
161
162 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
163 vec![&self.expr]
164 }
165
166 fn with_new_children(
167 self: Arc<Self>,
168 mut children: Vec<Arc<dyn PhysicalExpr>>,
169 ) -> Result<Arc<dyn PhysicalExpr>> {
170 assert_eq!(children.len(), 1);
171 let child = children.pop().expect("CastColumnExpr child");
172 Ok(Arc::new(Self::new(
173 child,
174 Arc::clone(&self.input_field),
175 Arc::clone(&self.target_field),
176 Some(self.cast_options.clone()),
177 )))
178 }
179
180 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 Display::fmt(self, f)
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::{Column, Literal};
    use arrow::{
        array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray, StructArray},
        datatypes::{DataType, Field, Fields, SchemaRef},
    };
    use datafusion_common::{
        cast::{as_int64_array, as_string_array, as_struct_array, as_uint8_array},
        Result as DFResult, ScalarValue,
    };

    /// Builds a one-column schema containing `field`.
    fn make_schema(field: &Field) -> SchemaRef {
        Arc::new(Schema::new(vec![field.clone()]))
    }

    /// Builds a struct array without a top-level null buffer.
    fn make_struct_array(fields: Fields, arrays: Vec<ArrayRef>) -> StructArray {
        StructArray::new(fields, arrays, None)
    }

    /// Builds a `CastColumnExpr` with default cast options.
    fn cast_expr(
        child: Arc<dyn PhysicalExpr>,
        input: &Field,
        target: &Field,
    ) -> CastColumnExpr {
        CastColumnExpr::new(child, Arc::new(input.clone()), Arc::new(target.clone()), None)
    }

    /// Evaluates `expr`, panicking unless the result is an array.
    fn evaluate_to_array(expr: &CastColumnExpr, batch: &RecordBatch) -> DFResult<ArrayRef> {
        match expr.evaluate(batch)? {
            ColumnarValue::Array(array) => Ok(array),
            ColumnarValue::Scalar(_) => panic!("expected array"),
        }
    }

    #[test]
    fn cast_primitive_array() -> DFResult<()> {
        let input_field = Field::new("a", DataType::Int32, true);
        let target_field = Field::new("a", DataType::Int64, true);
        let schema = make_schema(&input_field);

        let column_values: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![column_values])?;

        let column = Arc::new(Column::new_with_schema("a", &schema)?);
        let expr = cast_expr(column, &input_field, &target_field);

        let array = evaluate_to_array(&expr, &batch)?;
        let casted = as_int64_array(array.as_ref())?;
        assert_eq!(casted.value(0), 1);
        assert!(casted.is_null(1));
        assert_eq!(casted.value(2), 3);
        Ok(())
    }

    #[test]
    fn cast_struct_array_missing_child() -> DFResult<()> {
        let source_children: Fields = vec![
            Arc::new(Field::new("a", DataType::Int32, true)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ]
        .into();
        let input_field =
            Field::new("s", DataType::Struct(source_children.clone()), true);

        let target_children: Fields = vec![
            Arc::new(Field::new("a", DataType::Int64, true)),
            Arc::new(Field::new("c", DataType::Utf8, true)),
        ]
        .into();
        let target_field = Field::new("s", DataType::Struct(target_children), true);

        let schema = make_schema(&input_field);
        let struct_array = make_struct_array(
            source_children,
            vec![
                Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")]))
                    as ArrayRef,
            ],
        );
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(struct_array) as ArrayRef],
        )?;

        let column = Arc::new(Column::new_with_schema("s", &schema)?);
        let expr = cast_expr(column, &input_field, &target_field);

        let array = evaluate_to_array(&expr, &batch)?;
        let casted_struct = as_struct_array(array.as_ref())?;

        let cast_a =
            as_int64_array(casted_struct.column_by_name("a").unwrap().as_ref())?;
        assert_eq!(cast_a.value(0), 1);
        assert!(cast_a.is_null(1));

        // "c" does not exist in the input, so it must be all nulls.
        let cast_c =
            as_string_array(casted_struct.column_by_name("c").unwrap().as_ref())?;
        assert!(cast_c.is_null(0));
        assert!(cast_c.is_null(1));
        Ok(())
    }

    #[test]
    fn cast_nested_struct_array() -> DFResult<()> {
        let source_x: FieldRef = Arc::new(Field::new("x", DataType::Int32, true));
        let inner_source = Field::new(
            "inner",
            DataType::Struct(vec![Arc::clone(&source_x)].into()),
            true,
        );
        let outer_field = Field::new(
            "root",
            DataType::Struct(vec![Arc::new(inner_source.clone())].into()),
            true,
        );

        let inner_target = Field::new(
            "inner",
            DataType::Struct(
                vec![
                    Arc::new(Field::new("x", DataType::Int64, true)),
                    Arc::new(Field::new("y", DataType::Boolean, true)),
                ]
                .into(),
            ),
            true,
        );
        let target_field = Field::new(
            "root",
            DataType::Struct(vec![Arc::new(inner_target)].into()),
            true,
        );

        let schema = make_schema(&outer_field);

        let inner_struct = make_struct_array(
            vec![source_x].into(),
            vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef],
        );
        let outer_struct = make_struct_array(
            vec![Arc::new(inner_source)].into(),
            vec![Arc::new(inner_struct) as ArrayRef],
        );
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(outer_struct) as ArrayRef],
        )?;

        let column = Arc::new(Column::new_with_schema("root", &schema)?);
        let expr = cast_expr(column, &outer_field, &target_field);

        let array = evaluate_to_array(&expr, &batch)?;
        let root = as_struct_array(array.as_ref())?;
        let inner = as_struct_array(root.column_by_name("inner").unwrap().as_ref())?;

        let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?;
        assert_eq!(x.value(0), 7);
        assert!(x.is_null(1));

        // "y" is new in the target, so it must be all nulls.
        let y = inner
            .column_by_name("y")
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean array");
        assert!(y.is_null(0));
        assert!(y.is_null(1));
        Ok(())
    }

    #[test]
    fn cast_struct_scalar() -> DFResult<()> {
        let source_field = Field::new("a", DataType::Int32, true);
        let input_field = Field::new(
            "s",
            DataType::Struct(vec![Arc::new(source_field.clone())].into()),
            true,
        );
        let target_field = Field::new(
            "s",
            DataType::Struct(
                vec![Arc::new(Field::new("a", DataType::UInt8, true))].into(),
            ),
            true,
        );

        let schema = make_schema(&input_field);
        let scalar_struct = make_struct_array(
            vec![Arc::new(source_field)].into(),
            vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef],
        );
        let literal =
            Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct))));
        let expr = cast_expr(literal, &input_field, &target_field);

        // A literal child evaluates to a scalar regardless of the batch rows.
        let batch = RecordBatch::new_empty(Arc::clone(&schema));
        let ColumnarValue::Scalar(ScalarValue::Struct(array)) = expr.evaluate(&batch)?
        else {
            panic!("expected struct scalar");
        };
        let casted = as_uint8_array(array.column_by_name("a").unwrap().as_ref())?;
        assert_eq!(casted.value(0), 9);
        Ok(())
    }
}