datafusion_expr/
expr_schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{Between, Expr, Like};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25    data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf,
26};
27use crate::udf::ReturnFieldArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field, FieldRef};
31use datafusion_common::metadata::FieldMetadata;
32use datafusion_common::{
33    not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
34    Result, Spans, TableReference,
35};
36use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
37use datafusion_functions_window_common::field::WindowUDFFieldArgs;
38use std::sync::Arc;
39
40/// Trait to allow expr to typable with respect to a schema
41pub trait ExprSchemable {
42    /// Given a schema, return the type of the expr
43    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
44
45    /// Given a schema, return the nullability of the expr
46    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
47
48    /// Given a schema, return the expr's optional metadata
49    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata>;
50
51    /// Convert to a field with respect to a schema
52    fn to_field(
53        &self,
54        input_schema: &dyn ExprSchema,
55    ) -> Result<(Option<TableReference>, Arc<Field>)>;
56
57    /// Cast to a type with respect to a schema
58    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
59
60    /// Given a schema, return the type and nullability of the expr
61    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
62        -> Result<(DataType, bool)>;
63}
64
65impl ExprSchemable for Expr {
66    /// Returns the [arrow::datatypes::DataType] of the expression
67    /// based on [ExprSchema]
68    ///
69    /// Note: [`DFSchema`] implements [ExprSchema].
70    ///
71    /// [`DFSchema`]: datafusion_common::DFSchema
72    ///
73    /// # Examples
74    ///
75    /// Get the type of an expression that adds 2 columns. Adding an Int32
76    /// and Float32 results in Float32 type
77    ///
78    /// ```
79    /// # use arrow::datatypes::{DataType, Field};
80    /// # use datafusion_common::DFSchema;
81    /// # use datafusion_expr::{col, ExprSchemable};
82    /// # use std::collections::HashMap;
83    ///
84    /// fn main() {
85    ///     let expr = col("c1") + col("c2");
86    ///     let schema = DFSchema::from_unqualified_fields(
87    ///         vec![
88    ///             Field::new("c1", DataType::Int32, true),
89    ///             Field::new("c2", DataType::Float32, true),
90    ///         ]
91    ///         .into(),
92    ///         HashMap::new(),
93    ///     )
94    ///     .unwrap();
95    ///     assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
96    /// }
97    /// ```
98    ///
99    /// # Errors
100    ///
101    /// This function errors when it is not possible to compute its
102    /// [arrow::datatypes::DataType].  This happens when e.g. the
103    /// expression refers to a column that does not exist in the
104    /// schema, or when the expression is incorrectly typed
105    /// (e.g. `[utf8] + [bool]`).
106    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
107    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
108        match self {
109            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
110                Expr::Placeholder(Placeholder { field, .. }) => match &field {
111                    None => schema.data_type(&Column::from_name(name)).cloned(),
112                    Some(field) => Ok(field.data_type().clone()),
113                },
114                _ => expr.get_type(schema),
115            },
116            Expr::Negative(expr) => expr.get_type(schema),
117            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
118            Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()),
119            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
120            Expr::Literal(l, _) => Ok(l.data_type()),
121            Expr::Case(case) => {
122                for (_, then_expr) in &case.when_then_expr {
123                    let then_type = then_expr.get_type(schema)?;
124                    if !then_type.is_null() {
125                        return Ok(then_type);
126                    }
127                }
128                case.else_expr
129                    .as_ref()
130                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
131            }
132            Expr::Cast(Cast { data_type, .. })
133            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
134            Expr::Unnest(Unnest { expr }) => {
135                let arg_data_type = expr.get_type(schema)?;
136                // Unnest's output type is the inner type of the list
137                match arg_data_type {
138                    DataType::List(field)
139                    | DataType::LargeList(field)
140                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
141                    DataType::Struct(_) => Ok(arg_data_type),
142                    DataType::Null => {
143                        not_impl_err!("unnest() does not support null yet")
144                    }
145                    _ => {
146                        plan_err!(
147                            "unnest() can only be applied to array, struct and null"
148                        )
149                    }
150                }
151            }
152            Expr::ScalarFunction(_func) => {
153                let (return_type, _) = self.data_type_and_nullable(schema)?;
154                Ok(return_type)
155            }
156            Expr::WindowFunction(window_function) => self
157                .data_type_and_nullable_with_window_function(schema, window_function)
158                .map(|(return_type, _)| return_type),
159            Expr::AggregateFunction(AggregateFunction {
160                func,
161                params: AggregateFunctionParams { args, .. },
162            }) => {
163                let fields = args
164                    .iter()
165                    .map(|e| e.to_field(schema).map(|(_, f)| f))
166                    .collect::<Result<Vec<_>>>()?;
167                let new_fields = fields_with_aggregate_udf(&fields, func)
168                    .map_err(|err| {
169                        let data_types = fields
170                            .iter()
171                            .map(|f| f.data_type().clone())
172                            .collect::<Vec<_>>();
173                        plan_datafusion_err!(
174                            "{} {}",
175                            match err {
176                                DataFusionError::Plan(msg) => msg,
177                                err => err.to_string(),
178                            },
179                            utils::generate_signature_error_msg(
180                                func.name(),
181                                func.signature().clone(),
182                                &data_types
183                            )
184                        )
185                    })?
186                    .into_iter()
187                    .collect::<Vec<_>>();
188                Ok(func.return_field(&new_fields)?.data_type().clone())
189            }
190            Expr::Not(_)
191            | Expr::IsNull(_)
192            | Expr::Exists { .. }
193            | Expr::InSubquery(_)
194            | Expr::Between { .. }
195            | Expr::InList { .. }
196            | Expr::IsNotNull(_)
197            | Expr::IsTrue(_)
198            | Expr::IsFalse(_)
199            | Expr::IsUnknown(_)
200            | Expr::IsNotTrue(_)
201            | Expr::IsNotFalse(_)
202            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
203            Expr::ScalarSubquery(subquery) => {
204                Ok(subquery.subquery.schema().field(0).data_type().clone())
205            }
206            Expr::BinaryExpr(BinaryExpr {
207                ref left,
208                ref right,
209                ref op,
210            }) => BinaryTypeCoercer::new(
211                &left.get_type(schema)?,
212                op,
213                &right.get_type(schema)?,
214            )
215            .get_result_type(),
216            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
217            Expr::Placeholder(Placeholder { field, .. }) => {
218                if let Some(field) = field {
219                    Ok(field.data_type().clone())
220                } else {
221                    // If the placeholder's type hasn't been specified, treat it as
222                    // null (unspecified placeholders generate an error during planning)
223                    Ok(DataType::Null)
224                }
225            }
226            #[expect(deprecated)]
227            Expr::Wildcard { .. } => Ok(DataType::Null),
228            Expr::GroupingSet(_) => {
229                // Grouping sets do not really have a type and do not appear in projections
230                Ok(DataType::Null)
231            }
232        }
233    }
234
235    /// Returns the nullability of the expression based on [ExprSchema].
236    ///
237    /// Note: [`DFSchema`] implements [ExprSchema].
238    ///
239    /// [`DFSchema`]: datafusion_common::DFSchema
240    ///
241    /// # Errors
242    ///
243    /// This function errors when it is not possible to compute its
244    /// nullability.  This happens when the expression refers to a
245    /// column that does not exist in the schema.
246    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
247        match self {
248            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
249                expr.nullable(input_schema)
250            }
251
252            Expr::InList(InList { expr, list, .. }) => {
253                // Avoid inspecting too many expressions.
254                const MAX_INSPECT_LIMIT: usize = 6;
255                // Stop if a nullable expression is found or an error occurs.
256                let has_nullable = std::iter::once(expr.as_ref())
257                    .chain(list)
258                    .take(MAX_INSPECT_LIMIT)
259                    .find_map(|e| {
260                        e.nullable(input_schema)
261                            .map(|nullable| if nullable { Some(()) } else { None })
262                            .transpose()
263                    })
264                    .transpose()?;
265                Ok(match has_nullable {
266                    // If a nullable subexpression is found, the result may also be nullable.
267                    Some(_) => true,
268                    // If the list is too long, we assume it is nullable.
269                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
270                    // All the subexpressions are non-nullable, so the result must be non-nullable.
271                    _ => false,
272                })
273            }
274
275            Expr::Between(Between {
276                expr, low, high, ..
277            }) => Ok(expr.nullable(input_schema)?
278                || low.nullable(input_schema)?
279                || high.nullable(input_schema)?),
280
281            Expr::Column(c) => input_schema.nullable(c),
282            Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()),
283            Expr::Literal(value, _) => Ok(value.is_null()),
284            Expr::Case(case) => {
285                // This expression is nullable if any of the input expressions are nullable
286                let then_nullable = case
287                    .when_then_expr
288                    .iter()
289                    .map(|(_, t)| t.nullable(input_schema))
290                    .collect::<Result<Vec<_>>>()?;
291                if then_nullable.contains(&true) {
292                    Ok(true)
293                } else if let Some(e) = &case.else_expr {
294                    e.nullable(input_schema)
295                } else {
296                    // CASE produces NULL if there is no `else` expr
297                    // (aka when none of the `when_then_exprs` match)
298                    Ok(true)
299                }
300            }
301            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
302            Expr::ScalarFunction(_func) => {
303                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
304                Ok(nullable)
305            }
306            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
307                Ok(func.is_nullable())
308            }
309            Expr::WindowFunction(window_function) => self
310                .data_type_and_nullable_with_window_function(
311                    input_schema,
312                    window_function,
313                )
314                .map(|(_, nullable)| nullable),
315            Expr::Placeholder(Placeholder { id: _, field }) => {
316                Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true))
317            }
318            Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => {
319                Ok(true)
320            }
321            Expr::IsNull(_)
322            | Expr::IsNotNull(_)
323            | Expr::IsTrue(_)
324            | Expr::IsFalse(_)
325            | Expr::IsUnknown(_)
326            | Expr::IsNotTrue(_)
327            | Expr::IsNotFalse(_)
328            | Expr::IsNotUnknown(_)
329            | Expr::Exists { .. } => Ok(false),
330            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
331            Expr::ScalarSubquery(subquery) => {
332                Ok(subquery.subquery.schema().field(0).is_nullable())
333            }
334            Expr::BinaryExpr(BinaryExpr {
335                ref left,
336                ref right,
337                ..
338            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
339            Expr::Like(Like { expr, pattern, .. })
340            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
341                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
342            }
343            #[expect(deprecated)]
344            Expr::Wildcard { .. } => Ok(false),
345            Expr::GroupingSet(_) => {
346                // Grouping sets do not really have the concept of nullable and do not appear
347                // in projections
348                Ok(true)
349            }
350        }
351    }
352
353    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata> {
354        self.to_field(schema)
355            .map(|(_, field)| FieldMetadata::from(field.metadata()))
356    }
357
358    /// Returns the datatype and nullability of the expression based on [ExprSchema].
359    ///
360    /// Note: [`DFSchema`] implements [ExprSchema].
361    ///
362    /// [`DFSchema`]: datafusion_common::DFSchema
363    ///
364    /// # Errors
365    ///
366    /// This function errors when it is not possible to compute its
367    /// datatype or nullability.
368    fn data_type_and_nullable(
369        &self,
370        schema: &dyn ExprSchema,
371    ) -> Result<(DataType, bool)> {
372        let field = self.to_field(schema)?.1;
373
374        Ok((field.data_type().clone(), field.is_nullable()))
375    }
376
377    /// Returns a [arrow::datatypes::Field] compatible with this expression.
378    ///
379    /// This function converts an expression into a field with appropriate metadata
380    /// and nullability based on the expression type and context. It is the primary
381    /// mechanism for determining field-level schemas.
382    ///
383    /// # Field Property Resolution
384    ///
385    /// For each expression, the following properties are determined:
386    ///
387    /// ## Data Type Resolution
388    /// - **Column references**: Data type from input schema field
389    /// - **Literals**: Data type inferred from literal value
390    /// - **Aliases**: Data type inherited from the underlying expression (the aliased expression)
391    /// - **Binary expressions**: Result type from type coercion rules
392    /// - **Boolean expressions**: Always a boolean type
393    /// - **Cast expressions**: Target data type from cast operation
394    /// - **Function calls**: Return type based on function signature and argument types
395    ///
396    /// ## Nullability Determination
397    /// - **Column references**: Inherit nullability from input schema field
398    /// - **Literals**: Nullable only if literal value is NULL
399    /// - **Aliases**: Inherit nullability from the underlying expression (the aliased expression)
400    /// - **Binary expressions**: Nullable if either operand is nullable
401    /// - **Boolean expressions**: Always non-nullable (IS NULL, EXISTS, etc.)
402    /// - **Cast expressions**: determined by the input expression's nullability rules
403    /// - **Function calls**: Based on function nullability rules and input nullability
404    ///
405    /// ## Metadata Handling
406    /// - **Column references**: Preserve original field metadata from input schema
407    /// - **Literals**: Use explicitly provided metadata, otherwise empty
408    /// - **Aliases**: Merge underlying expr metadata with alias-specific metadata, preferring the alias metadata
409    /// - **Binary expressions**: field metadata is empty
410    /// - **Boolean expressions**: field metadata is empty
411    /// - **Cast expressions**: determined by the input expression's field metadata handling
412    /// - **Scalar functions**: Generate metadata via function's [`return_field_from_args`] method,
413    ///   with the default implementation returning empty field metadata
414    /// - **Aggregate functions**: Generate metadata via function's [`return_field`] method,
415    ///   with the default implementation returning empty field metadata
416    /// - **Window functions**: field metadata is empty
417    ///
418    /// ## Table Reference Scoping
419    /// - Establishes proper qualified field references when columns belong to specific tables
420    /// - Maintains table context for accurate field resolution in multi-table scenarios
421    ///
422    /// So for example, a projected expression `col(c1) + col(c2)` is
423    /// placed in an output field **named** col("c1 + c2")
424    ///
425    /// [`return_field_from_args`]: crate::ScalarUDF::return_field_from_args
426    /// [`return_field`]: crate::AggregateUDF::return_field
427    fn to_field(
428        &self,
429        schema: &dyn ExprSchema,
430    ) -> Result<(Option<TableReference>, Arc<Field>)> {
431        let (relation, schema_name) = self.qualified_name();
432        #[allow(deprecated)]
433        let field = match self {
434            Expr::Alias(Alias {
435                expr,
436                name: _,
437                metadata,
438                ..
439            }) => {
440                let field = expr.to_field(schema).map(|(_, f)| f.as_ref().clone())?;
441
442                let mut combined_metadata = expr.metadata(schema)?;
443                if let Some(metadata) = metadata {
444                    combined_metadata.extend(metadata.clone());
445                }
446
447                Ok(Arc::new(combined_metadata.add_to_field(field)))
448            }
449            Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f),
450            Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())),
451            Expr::OuterReferenceColumn(field, _) => {
452                Ok(Arc::new(field.as_ref().clone().with_name(&schema_name)))
453            }
454            Expr::ScalarVariable(ty, _) => {
455                Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
456            }
457            Expr::Literal(l, metadata) => {
458                let mut field = Field::new(&schema_name, l.data_type(), l.is_null());
459                if let Some(metadata) = metadata {
460                    field = metadata.add_to_field(field);
461                }
462                Ok(Arc::new(field))
463            }
464            Expr::IsNull(_)
465            | Expr::IsNotNull(_)
466            | Expr::IsTrue(_)
467            | Expr::IsFalse(_)
468            | Expr::IsUnknown(_)
469            | Expr::IsNotTrue(_)
470            | Expr::IsNotFalse(_)
471            | Expr::IsNotUnknown(_)
472            | Expr::Exists { .. } => {
473                Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false)))
474            }
475            Expr::ScalarSubquery(subquery) => {
476                Ok(Arc::clone(&subquery.subquery.schema().fields()[0]))
477            }
478            Expr::BinaryExpr(BinaryExpr {
479                ref left,
480                ref right,
481                ref op,
482            }) => {
483                let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
484                let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
485                let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
486                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
487                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
488                Ok(Arc::new(Field::new(
489                    &schema_name,
490                    coercer.get_result_type()?,
491                    lhs_nullable || rhs_nullable,
492                )))
493            }
494            Expr::WindowFunction(window_function) => {
495                let (dt, nullable) = self.data_type_and_nullable_with_window_function(
496                    schema,
497                    window_function,
498                )?;
499                Ok(Arc::new(Field::new(&schema_name, dt, nullable)))
500            }
501            Expr::AggregateFunction(aggregate_function) => {
502                let AggregateFunction {
503                    func,
504                    params: AggregateFunctionParams { args, .. },
505                    ..
506                } = aggregate_function;
507
508                let fields = args
509                    .iter()
510                    .map(|e| e.to_field(schema).map(|(_, f)| f))
511                    .collect::<Result<Vec<_>>>()?;
512                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
513                let new_fields = fields_with_aggregate_udf(&fields, func)
514                    .map_err(|err| {
515                        let arg_types = fields
516                            .iter()
517                            .map(|f| f.data_type())
518                            .cloned()
519                            .collect::<Vec<_>>();
520                        plan_datafusion_err!(
521                            "{} {}",
522                            match err {
523                                DataFusionError::Plan(msg) => msg,
524                                err => err.to_string(),
525                            },
526                            utils::generate_signature_error_msg(
527                                func.name(),
528                                func.signature().clone(),
529                                &arg_types,
530                            )
531                        )
532                    })?
533                    .into_iter()
534                    .collect::<Vec<_>>();
535
536                func.return_field(&new_fields)
537            }
538            Expr::ScalarFunction(ScalarFunction { func, args }) => {
539                let (arg_types, fields): (Vec<DataType>, Vec<Arc<Field>>) = args
540                    .iter()
541                    .map(|e| e.to_field(schema).map(|(_, f)| f))
542                    .collect::<Result<Vec<_>>>()?
543                    .into_iter()
544                    .map(|f| (f.data_type().clone(), f))
545                    .unzip();
546                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
547                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
548                    .map_err(|err| {
549                        plan_datafusion_err!(
550                            "{} {}",
551                            match err {
552                                DataFusionError::Plan(msg) => msg,
553                                err => err.to_string(),
554                            },
555                            utils::generate_signature_error_msg(
556                                func.name(),
557                                func.signature().clone(),
558                                &arg_types,
559                            )
560                        )
561                    })?;
562                let new_fields = fields
563                    .into_iter()
564                    .zip(new_data_types)
565                    .map(|(f, d)| f.as_ref().clone().with_data_type(d))
566                    .map(Arc::new)
567                    .collect::<Vec<FieldRef>>();
568
569                let arguments = args
570                    .iter()
571                    .map(|e| match e {
572                        Expr::Literal(sv, _) => Some(sv),
573                        _ => None,
574                    })
575                    .collect::<Vec<_>>();
576                let args = ReturnFieldArgs {
577                    arg_fields: &new_fields,
578                    scalar_arguments: &arguments,
579                };
580
581                func.return_field_from_args(args)
582            }
583            // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
584            Expr::Cast(Cast { expr, data_type }) => expr
585                .to_field(schema)
586                .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone()))
587                .map(Arc::new),
588            Expr::Placeholder(Placeholder {
589                id: _,
590                field: Some(field),
591            }) => Ok(field.as_ref().clone().with_name(&schema_name).into()),
592            Expr::Like(_)
593            | Expr::SimilarTo(_)
594            | Expr::Not(_)
595            | Expr::Between(_)
596            | Expr::Case(_)
597            | Expr::TryCast(_)
598            | Expr::InList(_)
599            | Expr::InSubquery(_)
600            | Expr::Wildcard { .. }
601            | Expr::GroupingSet(_)
602            | Expr::Placeholder(_)
603            | Expr::Unnest(_) => Ok(Arc::new(Field::new(
604                &schema_name,
605                self.get_type(schema)?,
606                self.nullable(schema)?,
607            ))),
608        }?;
609
610        Ok((
611            relation,
612            Arc::new(field.as_ref().clone().with_name(schema_name)),
613        ))
614    }
615
616    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
617    ///
618    /// # Errors
619    ///
620    /// This function errors when it is impossible to cast the
621    /// expression to the target [arrow::datatypes::DataType].
622    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
623        let this_type = self.get_type(schema)?;
624        if this_type == *cast_to_type {
625            return Ok(self);
626        }
627
628        // TODO(kszucs): Most of the operations do not validate the type correctness
629        // like all of the binary expressions below. Perhaps Expr should track the
630        // type of the expression?
631
632        if can_cast_types(&this_type, cast_to_type) {
633            match self {
634                Expr::ScalarSubquery(subquery) => {
635                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
636                }
637                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
638            }
639        } else {
640            plan_err!("Cannot automatically convert {this_type} to {cast_to_type}")
641        }
642    }
643}
644
645impl Expr {
646    /// Common method for window functions that applies type coercion
647    /// to all arguments of the window function to check if it matches
648    /// its signature.
649    ///
650    /// If successful, this method returns the data type and
651    /// nullability of the window function's result.
652    ///
653    /// Otherwise, returns an error if there's a type mismatch between
654    /// the window function's signature and the provided arguments.
655    fn data_type_and_nullable_with_window_function(
656        &self,
657        schema: &dyn ExprSchema,
658        window_function: &WindowFunction,
659    ) -> Result<(DataType, bool)> {
660        let WindowFunction {
661            fun,
662            params: WindowFunctionParams { args, .. },
663            ..
664        } = window_function;
665
666        let fields = args
667            .iter()
668            .map(|e| e.to_field(schema).map(|(_, f)| f))
669            .collect::<Result<Vec<_>>>()?;
670        match fun {
671            WindowFunctionDefinition::AggregateUDF(udaf) => {
672                let data_types = fields
673                    .iter()
674                    .map(|f| f.data_type())
675                    .cloned()
676                    .collect::<Vec<_>>();
677                let new_fields = fields_with_aggregate_udf(&fields, udaf)
678                    .map_err(|err| {
679                        plan_datafusion_err!(
680                            "{} {}",
681                            match err {
682                                DataFusionError::Plan(msg) => msg,
683                                err => err.to_string(),
684                            },
685                            utils::generate_signature_error_msg(
686                                fun.name(),
687                                fun.signature(),
688                                &data_types
689                            )
690                        )
691                    })?
692                    .into_iter()
693                    .collect::<Vec<_>>();
694
695                let return_field = udaf.return_field(&new_fields)?;
696
697                Ok((return_field.data_type().clone(), return_field.is_nullable()))
698            }
699            WindowFunctionDefinition::WindowUDF(udwf) => {
700                let data_types = fields
701                    .iter()
702                    .map(|f| f.data_type())
703                    .cloned()
704                    .collect::<Vec<_>>();
705                let new_fields = fields_with_window_udf(&fields, udwf)
706                    .map_err(|err| {
707                        plan_datafusion_err!(
708                            "{} {}",
709                            match err {
710                                DataFusionError::Plan(msg) => msg,
711                                err => err.to_string(),
712                            },
713                            utils::generate_signature_error_msg(
714                                fun.name(),
715                                fun.signature(),
716                                &data_types
717                            )
718                        )
719                    })?
720                    .into_iter()
721                    .collect::<Vec<_>>();
722                let (_, function_name) = self.qualified_name();
723                let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name);
724
725                udwf.field(field_args)
726                    .map(|field| (field.data_type().clone(), field.is_nullable()))
727            }
728        }
729    }
730}
731
732/// Cast subquery in InSubquery/ScalarSubquery to a given type.
733///
734/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
735///    columns), it casts the first expression in the projection to the target type and creates a
736///    new projection with the casted expression.
737/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
738///    with the casted first column.
739pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
740    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
741        return Ok(subquery);
742    }
743
744    let plan = subquery.subquery.as_ref();
745    let new_plan = match plan {
746        LogicalPlan::Projection(projection) => {
747            let cast_expr = projection.expr[0]
748                .clone()
749                .cast_to(cast_to_type, projection.input.schema())?;
750            LogicalPlan::Projection(Projection::try_new(
751                vec![cast_expr],
752                Arc::clone(&projection.input),
753            )?)
754        }
755        _ => {
756            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
757                .cast_to(cast_to_type, subquery.subquery.schema())?;
758            LogicalPlan::Projection(Projection::try_new(
759                vec![cast_expr],
760                subquery.subquery,
761            )?)
762        }
763    };
764    Ok(Subquery {
765        subquery: Arc::new(new_plan),
766        outer_ref_columns: subquery.outer_ref_columns,
767        spans: Spans::new(),
768    })
769}
770
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{col, lit, out_ref_col_with_metadata};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Applies the `Expr` predicate method `$method` to a NULL literal and asserts
    /// that the resulting expression is reported as NOT nullable.
    macro_rules! assert_null_predicate_not_nullable {
        ($method:ident) => {{
            let predicate = lit(ScalarValue::Null).$method();
            assert!(!predicate.nullable(&MockExprSchema::new()).unwrap());
        }};
    }

    /// Builds a mock schema whose single field has `data_type` and `nullable`.
    fn typed_schema(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let eq_expr = col("foo").eq(lit(1));
        assert!(!eq_expr.nullable(&MockExprSchema::new()).unwrap());
        let nullable_schema = MockExprSchema::new().with_nullable(true);
        assert!(eq_expr.nullable(&nullable_schema).unwrap());

        assert_null_predicate_not_nullable!(is_null);
        assert_null_predicate_not_nullable!(is_not_null);
        assert_null_predicate_not_nullable!(is_true);
        assert_null_predicate_not_nullable!(is_not_true);
        assert_null_predicate_not_nullable!(is_false);
        assert_null_predicate_not_nullable!(is_not_false);
        assert_null_predicate_not_nullable!(is_unknown);
        assert_null_predicate_not_nullable!(is_not_unknown);
    }

    #[test]
    fn test_between_nullability() {
        let schema_of = |nullable| typed_schema(DataType::Int32, nullable);
        let null_i32 = lit(ScalarValue::Int32(None));

        let bounded = col("foo").between(lit(1), lit(2));
        assert!(!bounded.nullable(&schema_of(false)).unwrap());
        assert!(bounded.nullable(&schema_of(true)).unwrap());

        // Any NULL bound makes BETWEEN nullable, even over a non-nullable column.
        for (low, high) in [
            (null_i32.clone(), lit(2)),
            (lit(1), null_i32.clone()),
            (null_i32.clone(), null_i32),
        ] {
            let with_null_bound = col("foo").between(low, high);
            assert!(with_null_bound.nullable(&schema_of(false)).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let schema_of = |nullable| typed_schema(DataType::Int32, nullable);

        let short_list = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!short_list.nullable(&schema_of(false)).unwrap());
        assert!(short_list.nullable(&schema_of(true)).unwrap());
        // An error raised by the schema's nullable() lookup must be propagated.
        let failing_schema = schema_of(false).with_error_on_nullable(true);
        assert!(short_list.nullable(&failing_schema).is_err());

        let with_null_item =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null_item.nullable(&schema_of(false)).unwrap());

        // A six-element list of non-null literals is still reported as nullable.
        let long_list = col("foo").in_list(vec![lit(1); 6], false);
        assert!(long_list.nullable(&schema_of(false)).unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let schema_of = |nullable| typed_schema(DataType::Utf8, nullable);

        let pattern_match = col("foo").like(lit("bar"));
        assert!(!pattern_match.nullable(&schema_of(false)).unwrap());
        assert!(pattern_match.nullable(&schema_of(true)).unwrap());

        let null_pattern = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(null_pattern.nullable(&schema_of(false)).unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let resolved = col("foo").get_type(&schema).unwrap();
        assert_eq!(DataType::Utf8, resolved);
    }

    #[test]
    fn test_expr_metadata() {
        let meta = FieldMetadata::from(HashMap::from([(
            "bar".to_string(),
            "buzz".to_string(),
        )]));
        let foo = col("foo");
        let mock = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // Column references, aliases and casts all carry the column's metadata through.
        assert_eq!(meta, foo.metadata(&mock).unwrap());
        assert_eq!(meta, foo.clone().alias("bar").metadata(&mock).unwrap());
        let casted = foo.clone().cast_to(&DataType::Int64, &mock).unwrap();
        assert_eq!(meta, casted.metadata(&mock).unwrap());

        let df_schema = DFSchema::from_unqualified_fields(
            vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(),
            HashMap::new(),
        )
        .unwrap();

        // Resolving against a real DFSchema also surfaces the field metadata.
        assert_eq!(meta, foo.metadata(&df_schema).unwrap());

        // An outer reference built with explicit metadata reports that metadata.
        let outer_ref = out_ref_col_with_metadata(
            DataType::Int32,
            meta.to_hashmap(),
            Column::from_name("foo"),
        );
        assert_eq!(meta, outer_ref.metadata(&df_schema).unwrap());
    }

    #[test]
    fn test_expr_placeholder() {
        const ALIAS: &str = "a placeholder by any other name";
        let schema = MockExprSchema::new();
        let placeholder_meta = FieldMetadata::from(HashMap::from([(
            "bar".to_string(),
            "buzz".to_string(),
        )]));

        // A nullable placeholder with metadata keeps both, with or without an alias.
        let annotated_field = Field::new("", DataType::Utf8, true)
            .with_metadata(placeholder_meta.to_hashmap());
        let annotated = Expr::Placeholder(Placeholder::new_with_field(
            "".to_string(),
            Some(annotated_field.into()),
        ));
        for candidate in [annotated.clone(), annotated.alias(ALIAS)] {
            assert_eq!(
                candidate.data_type_and_nullable(&schema).unwrap(),
                (DataType::Utf8, true)
            );
            assert_eq!(placeholder_meta, candidate.metadata(&schema).unwrap());
        }

        // A non-nullable placeholder stays non-nullable, with or without an alias.
        let strict = Expr::Placeholder(Placeholder::new_with_field(
            "".to_string(),
            Some(Field::new("", DataType::Utf8, false).into()),
        ));
        for candidate in [strict.clone(), strict.alias(ALIAS)] {
            assert_eq!(
                candidate.data_type_and_nullable(&schema).unwrap(),
                (DataType::Utf8, false)
            );
        }
    }

    /// Single-field `ExprSchema` stand-in: every column resolves to `field`, and
    /// `nullable()` can be forced to return an error via `error_on_nullable`.
    #[derive(Debug)]
    struct MockExprSchema {
        field: Field,
        error_on_nullable: bool,
    }

    impl MockExprSchema {
        /// Starts from a non-nullable `DataType::Null` field named "mock_field".
        fn new() -> Self {
            let field = Field::new("mock_field", DataType::Null, false);
            Self {
                field,
                error_on_nullable: false,
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self {
                field: self.field.with_nullable(nullable),
                ..self
            }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self {
                field: self.field.with_data_type(data_type),
                ..self
            }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: FieldMetadata) -> Self {
            Self {
                field: metadata.add_to_field(self.field),
                ..self
            }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if !self.error_on_nullable {
                return Ok(self.field.is_nullable());
            }
            internal_err!("nullable error")
        }

        fn field_from_column(&self, _col: &Column) -> Result<&Field> {
            Ok(&self.field)
        }
    }
}