Skip to main content

datafusion_expr/
expr_schema.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use super::{Between, Expr, Like, predicate_bounds};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf};
25use crate::udf::ReturnFieldArgs;
26use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils};
27use arrow::compute::can_cast_types;
28use arrow::datatypes::{DataType, Field, FieldRef};
29use datafusion_common::datatype::FieldExt;
30use datafusion_common::metadata::FieldMetadata;
31use datafusion_common::{
32    Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference,
33    not_impl_err, plan_datafusion_err, plan_err,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::sync::Arc;
38
/// Trait that allows an expression to be typed with respect to a schema.
///
/// Every method takes an [`ExprSchema`] and answers a schema-dependent
/// question about the expression: its data type, its nullability, its field
/// metadata, or the complete field describing it.
pub trait ExprSchemable {
    /// Given a schema, return the type of the expr
    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;

    /// Given a schema, return the nullability of the expr
    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;

    /// Given a schema, return the expr's optional metadata
    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata>;

    /// Convert to a field with respect to a schema
    ///
    /// The returned tuple holds an optional table reference together with
    /// the field itself.
    fn to_field(
        &self,
        input_schema: &dyn ExprSchema,
    ) -> Result<(Option<TableReference>, Arc<Field>)>;

    /// Cast to a type with respect to a schema
    ///
    /// Takes the expression by value and returns the resulting expression.
    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;

    /// Given a schema, return the type and nullability of the expr
    ///
    /// Deprecated since 51.0.0: read both values from the field returned by
    /// `to_field` instead.
    #[deprecated(
        since = "51.0.0",
        note = "Use `to_field().1.is_nullable` and `to_field().1.data_type()` directly instead"
    )]
    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
    -> Result<(DataType, bool)>;
}
67
68impl ExprSchemable for Expr {
69    /// Returns the [arrow::datatypes::DataType] of the expression
70    /// based on [ExprSchema]
71    ///
72    /// Note: [`DFSchema`] implements [ExprSchema].
73    ///
74    /// [`DFSchema`]: datafusion_common::DFSchema
75    ///
76    /// # Examples
77    ///
78    /// Get the type of an expression that adds 2 columns. Adding an Int32
79    /// and Float32 results in Float32 type
80    ///
81    /// ```
82    /// # use arrow::datatypes::{DataType, Field};
83    /// # use datafusion_common::DFSchema;
84    /// # use datafusion_expr::{col, ExprSchemable};
85    /// # use std::collections::HashMap;
86    ///
87    /// fn main() {
88    ///     let expr = col("c1") + col("c2");
89    ///     let schema = DFSchema::from_unqualified_fields(
90    ///         vec![
91    ///             Field::new("c1", DataType::Int32, true),
92    ///             Field::new("c2", DataType::Float32, true),
93    ///         ]
94    ///         .into(),
95    ///         HashMap::new(),
96    ///     )
97    ///     .unwrap();
98    ///     assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
99    /// }
100    /// ```
101    ///
102    /// # Errors
103    ///
104    /// This function errors when it is not possible to compute its
105    /// [arrow::datatypes::DataType].  This happens when e.g. the
106    /// expression refers to a column that does not exist in the
107    /// schema, or when the expression is incorrectly typed
108    /// (e.g. `[utf8] + [bool]`).
109    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
110    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
111        match self {
112            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
113                Expr::Placeholder(Placeholder { field, .. }) => match &field {
114                    None => schema.data_type(&Column::from_name(name)).cloned(),
115                    Some(field) => Ok(field.data_type().clone()),
116                },
117                _ => expr.get_type(schema),
118            },
119            Expr::Negative(expr) => expr.get_type(schema),
120            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
121            Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()),
122            Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()),
123            Expr::Literal(l, _) => Ok(l.data_type()),
124            Expr::Case(case) => {
125                for (_, then_expr) in &case.when_then_expr {
126                    let then_type = then_expr.get_type(schema)?;
127                    if !then_type.is_null() {
128                        return Ok(then_type);
129                    }
130                }
131                case.else_expr
132                    .as_ref()
133                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
134            }
135            Expr::Cast(Cast { data_type, .. })
136            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
137            Expr::Unnest(Unnest { expr }) => {
138                let arg_data_type = expr.get_type(schema)?;
139                // Unnest's output type is the inner type of the list
140                match arg_data_type {
141                    DataType::List(field)
142                    | DataType::LargeList(field)
143                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
144                    DataType::Struct(_) => Ok(arg_data_type),
145                    DataType::Null => {
146                        not_impl_err!("unnest() does not support null yet")
147                    }
148                    _ => {
149                        plan_err!(
150                            "unnest() can only be applied to array, struct and null"
151                        )
152                    }
153                }
154            }
155            Expr::ScalarFunction(_)
156            | Expr::WindowFunction(_)
157            | Expr::AggregateFunction(_) => {
158                Ok(self.to_field(schema)?.1.data_type().clone())
159            }
160            Expr::Not(_)
161            | Expr::IsNull(_)
162            | Expr::Exists { .. }
163            | Expr::InSubquery(_)
164            | Expr::SetComparison(_)
165            | Expr::Between { .. }
166            | Expr::InList { .. }
167            | Expr::IsNotNull(_)
168            | Expr::IsTrue(_)
169            | Expr::IsFalse(_)
170            | Expr::IsUnknown(_)
171            | Expr::IsNotTrue(_)
172            | Expr::IsNotFalse(_)
173            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
174            Expr::ScalarSubquery(subquery) => {
175                Ok(subquery.subquery.schema().field(0).data_type().clone())
176            }
177            Expr::BinaryExpr(BinaryExpr { left, right, op }) => BinaryTypeCoercer::new(
178                &left.get_type(schema)?,
179                op,
180                &right.get_type(schema)?,
181            )
182            .get_result_type(),
183            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
184            Expr::Placeholder(Placeholder { field, .. }) => {
185                if let Some(field) = field {
186                    Ok(field.data_type().clone())
187                } else {
188                    // If the placeholder's type hasn't been specified, treat it as
189                    // null (unspecified placeholders generate an error during planning)
190                    Ok(DataType::Null)
191                }
192            }
193            #[expect(deprecated)]
194            Expr::Wildcard { .. } => Ok(DataType::Null),
195            Expr::GroupingSet(_) => {
196                // Grouping sets do not really have a type and do not appear in projections
197                Ok(DataType::Null)
198            }
199        }
200    }
201
202    /// Returns the nullability of the expression based on [ExprSchema].
203    ///
204    /// Note: [`DFSchema`] implements [ExprSchema].
205    ///
206    /// [`DFSchema`]: datafusion_common::DFSchema
207    ///
208    /// # Errors
209    ///
210    /// This function errors when it is not possible to compute its
211    /// nullability.  This happens when the expression refers to a
212    /// column that does not exist in the schema.
213    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
214        match self {
215            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
216                expr.nullable(input_schema)
217            }
218
219            Expr::InList(InList { expr, list, .. }) => {
220                // Avoid inspecting too many expressions.
221                const MAX_INSPECT_LIMIT: usize = 6;
222                // Stop if a nullable expression is found or an error occurs.
223                let has_nullable = std::iter::once(expr.as_ref())
224                    .chain(list)
225                    .take(MAX_INSPECT_LIMIT)
226                    .find_map(|e| {
227                        e.nullable(input_schema)
228                            .map(|nullable| if nullable { Some(()) } else { None })
229                            .transpose()
230                    })
231                    .transpose()?;
232                Ok(match has_nullable {
233                    // If a nullable subexpression is found, the result may also be nullable.
234                    Some(_) => true,
235                    // If the list is too long, we assume it is nullable.
236                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
237                    // All the subexpressions are non-nullable, so the result must be non-nullable.
238                    _ => false,
239                })
240            }
241
242            Expr::Between(Between {
243                expr, low, high, ..
244            }) => Ok(expr.nullable(input_schema)?
245                || low.nullable(input_schema)?
246                || high.nullable(input_schema)?),
247
248            Expr::Column(c) => input_schema.nullable(c),
249            Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()),
250            Expr::Literal(value, _) => Ok(value.is_null()),
251            Expr::Case(case) => {
252                let nullable_then = case
253                    .when_then_expr
254                    .iter()
255                    .filter_map(|(w, t)| {
256                        let is_nullable = match t.nullable(input_schema) {
257                            Err(e) => return Some(Err(e)),
258                            Ok(n) => n,
259                        };
260
261                        // Branches with a then expression that is not nullable do not impact the
262                        // nullability of the case expression.
263                        if !is_nullable {
264                            return None;
265                        }
266
267                        // For case-with-expression assume all 'then' expressions are reachable
268                        if case.expr.is_some() {
269                            return Some(Ok(()));
270                        }
271
272                        // For branches with a nullable 'then' expression, try to determine
273                        // if the 'then' expression is ever reachable in the situation where
274                        // it would evaluate to null.
275                        let bounds = match predicate_bounds::evaluate_bounds(
276                            w,
277                            Some(unwrap_certainly_null_expr(t)),
278                            input_schema,
279                        ) {
280                            Err(e) => return Some(Err(e)),
281                            Ok(b) => b,
282                        };
283
284                        let can_be_true = match bounds
285                            .contains_value(ScalarValue::Boolean(Some(true)))
286                        {
287                            Err(e) => return Some(Err(e)),
288                            Ok(b) => b,
289                        };
290
291                        if !can_be_true {
292                            // If the derived 'when' expression can never evaluate to true, the
293                            // 'then' expression is not reachable when it would evaluate to NULL.
294                            // The most common pattern for this is `WHEN x IS NOT NULL THEN x`.
295                            None
296                        } else {
297                            // The branch might be taken
298                            Some(Ok(()))
299                        }
300                    })
301                    .next();
302
303                if let Some(nullable_then) = nullable_then {
304                    // There is at least one reachable nullable 'then' expression, so the case
305                    // expression itself is nullable.
306                    // Use `Result::map` to propagate the error from `nullable_then` if there is one.
307                    nullable_then.map(|_| true)
308                } else if let Some(e) = &case.else_expr {
309                    // There are no reachable nullable 'then' expressions, so all we still need to
310                    // check is the 'else' expression's nullability.
311                    e.nullable(input_schema)
312                } else {
313                    // CASE produces NULL if there is no `else` expr
314                    // (aka when none of the `when_then_exprs` match)
315                    Ok(true)
316                }
317            }
318            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
319            Expr::ScalarFunction(_)
320            | Expr::AggregateFunction(_)
321            | Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()),
322            Expr::ScalarVariable(field, _) => Ok(field.is_nullable()),
323            Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true),
324            Expr::IsNull(_)
325            | Expr::IsNotNull(_)
326            | Expr::IsTrue(_)
327            | Expr::IsFalse(_)
328            | Expr::IsUnknown(_)
329            | Expr::IsNotTrue(_)
330            | Expr::IsNotFalse(_)
331            | Expr::IsNotUnknown(_)
332            | Expr::Exists { .. } => Ok(false),
333            Expr::SetComparison(_) => Ok(true),
334            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
335            Expr::ScalarSubquery(subquery) => {
336                Ok(subquery.subquery.schema().field(0).is_nullable())
337            }
338            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
339                Ok(left.nullable(input_schema)? || right.nullable(input_schema)?)
340            }
341            Expr::Like(Like { expr, pattern, .. })
342            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
343                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
344            }
345            #[expect(deprecated)]
346            Expr::Wildcard { .. } => Ok(false),
347            Expr::GroupingSet(_) => {
348                // Grouping sets do not really have the concept of nullable and do not appear
349                // in projections
350                Ok(true)
351            }
352        }
353    }
354
355    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata> {
356        self.to_field(schema)
357            .map(|(_, field)| FieldMetadata::from(field.metadata()))
358    }
359
360    /// Returns the datatype and nullability of the expression based on [ExprSchema].
361    ///
362    /// Note: [`DFSchema`] implements [ExprSchema].
363    ///
364    /// [`DFSchema`]: datafusion_common::DFSchema
365    ///
366    /// # Errors
367    ///
368    /// This function errors when it is not possible to compute its
369    /// datatype or nullability.
370    fn data_type_and_nullable(
371        &self,
372        schema: &dyn ExprSchema,
373    ) -> Result<(DataType, bool)> {
374        let field = self.to_field(schema)?.1;
375
376        Ok((field.data_type().clone(), field.is_nullable()))
377    }
378
379    /// Returns a [arrow::datatypes::Field] compatible with this expression.
380    ///
381    /// This function converts an expression into a field with appropriate metadata
382    /// and nullability based on the expression type and context. It is the primary
383    /// mechanism for determining field-level schemas.
384    ///
385    /// # Field Property Resolution
386    ///
387    /// For each expression, the following properties are determined:
388    ///
389    /// ## Data Type Resolution
390    /// - **Column references**: Data type from input schema field
391    /// - **Literals**: Data type inferred from literal value
392    /// - **Aliases**: Data type inherited from the underlying expression (the aliased expression)
393    /// - **Binary expressions**: Result type from type coercion rules
394    /// - **Boolean expressions**: Always a boolean type
395    /// - **Cast expressions**: Target data type from cast operation
396    /// - **Function calls**: Return type based on function signature and argument types
397    ///
398    /// ## Nullability Determination
399    /// - **Column references**: Inherit nullability from input schema field
400    /// - **Literals**: Nullable only if literal value is NULL
401    /// - **Aliases**: Inherit nullability from the underlying expression (the aliased expression)
402    /// - **Binary expressions**: Nullable if either operand is nullable
403    /// - **Boolean expressions**: Always non-nullable (IS NULL, EXISTS, etc.)
404    /// - **Cast expressions**: determined by the input expression's nullability rules
405    /// - **Function calls**: Based on function nullability rules and input nullability
406    ///
407    /// ## Metadata Handling
408    /// - **Column references**: Preserve original field metadata from input schema
409    /// - **Literals**: Use explicitly provided metadata, otherwise empty
410    /// - **Aliases**: Merge underlying expr metadata with alias-specific metadata, preferring the alias metadata
411    /// - **Binary expressions**: field metadata is empty
412    /// - **Boolean expressions**: field metadata is empty
413    /// - **Cast expressions**: determined by the input expression's field metadata handling
414    /// - **Scalar functions**: Generate metadata via function's [`return_field_from_args`] method,
415    ///   with the default implementation returning empty field metadata
416    /// - **Aggregate functions**: Generate metadata via function's [`return_field`] method,
417    ///   with the default implementation returning empty field metadata
418    /// - **Window functions**: field metadata follows the function's return field
419    ///
420    /// ## Table Reference Scoping
421    /// - Establishes proper qualified field references when columns belong to specific tables
422    /// - Maintains table context for accurate field resolution in multi-table scenarios
423    ///
424    /// So for example, a projected expression `col(c1) + col(c2)` is
425    /// placed in an output field **named** col("c1 + c2")
426    ///
427    /// [`return_field_from_args`]: crate::ScalarUDF::return_field_from_args
428    /// [`return_field`]: crate::AggregateUDF::return_field
429    fn to_field(
430        &self,
431        schema: &dyn ExprSchema,
432    ) -> Result<(Option<TableReference>, Arc<Field>)> {
433        let (relation, schema_name) = self.qualified_name();
434        #[expect(deprecated)]
435        let field = match self {
436            Expr::Alias(Alias {
437                expr,
438                name: _,
439                metadata,
440                ..
441            }) => {
442                let mut combined_metadata = expr.metadata(schema)?;
443                if let Some(metadata) = metadata {
444                    combined_metadata.extend(metadata.clone());
445                }
446
447                Ok(expr
448                    .to_field(schema)
449                    .map(|(_, f)| f)?
450                    .with_field_metadata(&combined_metadata))
451            }
452            Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f),
453            Expr::Column(c) => schema.field_from_column(c).map(Arc::clone),
454            Expr::OuterReferenceColumn(field, _) => {
455                Ok(Arc::clone(field).renamed(&schema_name))
456            }
457            Expr::ScalarVariable(field, _) => Ok(Arc::clone(field).renamed(&schema_name)),
458            Expr::Literal(l, metadata) => Ok(Arc::new(
459                Field::new(&schema_name, l.data_type(), l.is_null())
460                    .with_field_metadata_opt(metadata.as_ref()),
461            )),
462            Expr::IsNull(_)
463            | Expr::IsNotNull(_)
464            | Expr::IsTrue(_)
465            | Expr::IsFalse(_)
466            | Expr::IsUnknown(_)
467            | Expr::IsNotTrue(_)
468            | Expr::IsNotFalse(_)
469            | Expr::IsNotUnknown(_)
470            | Expr::Exists { .. } => {
471                Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false)))
472            }
473            Expr::ScalarSubquery(subquery) => {
474                Ok(Arc::clone(&subquery.subquery.schema().fields()[0]))
475            }
476            Expr::BinaryExpr(BinaryExpr { left, right, op }) => {
477                let (left_field, right_field) =
478                    (left.to_field(schema)?.1, right.to_field(schema)?.1);
479
480                let (lhs_type, lhs_nullable) =
481                    (left_field.data_type(), left_field.is_nullable());
482                let (rhs_type, rhs_nullable) =
483                    (right_field.data_type(), right_field.is_nullable());
484                let mut coercer = BinaryTypeCoercer::new(lhs_type, op, rhs_type);
485                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
486                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
487                Ok(Arc::new(Field::new(
488                    &schema_name,
489                    coercer.get_result_type()?,
490                    lhs_nullable || rhs_nullable,
491                )))
492            }
493            Expr::WindowFunction(window_function) => {
494                let WindowFunction {
495                    fun,
496                    params: WindowFunctionParams { args, .. },
497                    ..
498                } = window_function.as_ref();
499
500                let fields = args
501                    .iter()
502                    .map(|e| e.to_field(schema).map(|(_, f)| f))
503                    .collect::<Result<Vec<_>>>()?;
504                match fun {
505                    WindowFunctionDefinition::AggregateUDF(udaf) => {
506                        let new_fields =
507                            verify_function_arguments(udaf.as_ref(), &fields)?;
508                        let return_field = udaf.return_field(&new_fields)?;
509                        Ok(return_field)
510                    }
511                    WindowFunctionDefinition::WindowUDF(udwf) => {
512                        let new_fields =
513                            verify_function_arguments(udwf.as_ref(), &fields)?;
514                        let return_field = udwf
515                            .field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?;
516                        Ok(return_field)
517                    }
518                }
519            }
520            Expr::AggregateFunction(AggregateFunction {
521                func,
522                params: AggregateFunctionParams { args, .. },
523            }) => {
524                let fields = args
525                    .iter()
526                    .map(|e| e.to_field(schema).map(|(_, f)| f))
527                    .collect::<Result<Vec<_>>>()?;
528                let new_fields = verify_function_arguments(func.as_ref(), &fields)?;
529                func.return_field(&new_fields)
530            }
531            Expr::ScalarFunction(ScalarFunction { func, args }) => {
532                let fields = args
533                    .iter()
534                    .map(|e| e.to_field(schema).map(|(_, f)| f))
535                    .collect::<Result<Vec<_>>>()?;
536                let new_fields = verify_function_arguments(func.as_ref(), &fields)?;
537
538                let arguments = args
539                    .iter()
540                    .map(|e| match e {
541                        Expr::Literal(sv, _) => Some(sv),
542                        _ => None,
543                    })
544                    .collect::<Vec<_>>();
545                let args = ReturnFieldArgs {
546                    arg_fields: &new_fields,
547                    scalar_arguments: &arguments,
548                };
549
550                func.return_field_from_args(args)
551            }
552            // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
553            Expr::Cast(Cast { expr, data_type }) => expr
554                .to_field(schema)
555                .map(|(_, f)| f.retyped(data_type.clone())),
556            Expr::Placeholder(Placeholder {
557                id: _,
558                field: Some(field),
559            }) => Ok(Arc::clone(field).renamed(&schema_name)),
560            Expr::Like(_)
561            | Expr::SimilarTo(_)
562            | Expr::Not(_)
563            | Expr::Between(_)
564            | Expr::Case(_)
565            | Expr::TryCast(_)
566            | Expr::InList(_)
567            | Expr::InSubquery(_)
568            | Expr::SetComparison(_)
569            | Expr::Wildcard { .. }
570            | Expr::GroupingSet(_)
571            | Expr::Placeholder(_)
572            | Expr::Unnest(_) => Ok(Arc::new(Field::new(
573                &schema_name,
574                self.get_type(schema)?,
575                self.nullable(schema)?,
576            ))),
577        }?;
578
579        Ok((
580            relation,
581            // todo avoid this rename / use the name above
582            field.renamed(&schema_name),
583        ))
584    }
585
586    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
587    ///
588    /// # Errors
589    ///
590    /// This function errors when it is impossible to cast the
591    /// expression to the target [arrow::datatypes::DataType].
592    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
593        let this_type = self.get_type(schema)?;
594        if this_type == *cast_to_type {
595            return Ok(self);
596        }
597
598        // TODO(kszucs): Most of the operations do not validate the type correctness
599        // like all of the binary expressions below. Perhaps Expr should track the
600        // type of the expression?
601
602        // Special handling for struct-to-struct casts with name-based field matching
603        let can_cast = match (&this_type, cast_to_type) {
604            (DataType::Struct(_), DataType::Struct(_)) => {
605                // Always allow struct-to-struct casts; field matching happens at runtime
606                true
607            }
608            _ => can_cast_types(&this_type, cast_to_type),
609        };
610
611        if can_cast {
612            match self {
613                Expr::ScalarSubquery(subquery) => {
614                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
615                }
616                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
617            }
618        } else {
619            plan_err!("Cannot automatically convert {this_type} to {cast_to_type}")
620        }
621    }
622}
623
624/// Verify that function is invoked with correct number and type of arguments as
625/// defined in `TypeSignature`.
626fn verify_function_arguments<F: UDFCoercionExt>(
627    function: &F,
628    input_fields: &[FieldRef],
629) -> Result<Vec<FieldRef>> {
630    fields_with_udf(input_fields, function).map_err(|err| {
631        let data_types = input_fields
632            .iter()
633            .map(|f| f.data_type())
634            .cloned()
635            .collect::<Vec<_>>();
636        plan_datafusion_err!(
637            "{} {}",
638            match err {
639                DataFusionError::Plan(msg) => msg,
640                err => err.to_string(),
641            },
642            utils::generate_signature_error_message(
643                function.name(),
644                function.signature(),
645                &data_types
646            )
647        )
648    })
649}
650
651/// Returns the innermost [Expr] that is provably null if `expr` is null.
652fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr {
653    match expr {
654        Expr::Not(e) => unwrap_certainly_null_expr(e),
655        Expr::Negative(e) => unwrap_certainly_null_expr(e),
656        Expr::Cast(e) => unwrap_certainly_null_expr(e.expr.as_ref()),
657        _ => expr,
658    }
659}
660
661/// Cast subquery in InSubquery/ScalarSubquery to a given type.
662///
663/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
664///    columns), it casts the first expression in the projection to the target type and creates a
665///    new projection with the casted expression.
666/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
667///    with the casted first column.
668pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
669    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
670        return Ok(subquery);
671    }
672
673    let plan = subquery.subquery.as_ref();
674    let new_plan = match plan {
675        LogicalPlan::Projection(projection) => {
676            let cast_expr = projection.expr[0]
677                .clone()
678                .cast_to(cast_to_type, projection.input.schema())?;
679            LogicalPlan::Projection(Projection::try_new(
680                vec![cast_expr],
681                Arc::clone(&projection.input),
682            )?)
683        }
684        _ => {
685            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
686                .cast_to(cast_to_type, subquery.subquery.schema())?;
687            LogicalPlan::Projection(Projection::try_new(
688                vec![cast_expr],
689                subquery.subquery,
690            )?)
691        }
692    };
693    Ok(Subquery {
694        subquery: Arc::new(new_plan),
695        outer_ref_columns: subquery.outer_ref_columns,
696        spans: Spans::new(),
697    })
698}
699
700#[cfg(test)]
701mod tests {
702    use std::collections::HashMap;
703
704    use super::*;
705    use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when};
706
707    use arrow::datatypes::FieldRef;
708    use datafusion_common::{DFSchema, ScalarValue, assert_or_internal_err};
709
710    macro_rules! test_is_expr_nullable {
711        ($EXPR_TYPE:ident) => {{
712            let expr = lit(ScalarValue::Null).$EXPR_TYPE();
713            assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
714        }};
715    }
716
717    #[test]
718    fn expr_schema_nullability() {
719        let expr = col("foo").eq(lit(1));
720        assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
721        assert!(
722            expr.nullable(&MockExprSchema::new().with_nullable(true))
723                .unwrap()
724        );
725
726        test_is_expr_nullable!(is_null);
727        test_is_expr_nullable!(is_not_null);
728        test_is_expr_nullable!(is_true);
729        test_is_expr_nullable!(is_not_true);
730        test_is_expr_nullable!(is_false);
731        test_is_expr_nullable!(is_not_false);
732        test_is_expr_nullable!(is_unknown);
733        test_is_expr_nullable!(is_not_unknown);
734    }
735
736    #[test]
737    fn test_between_nullability() {
738        let get_schema = |nullable| {
739            MockExprSchema::new()
740                .with_data_type(DataType::Int32)
741                .with_nullable(nullable)
742        };
743
744        let expr = col("foo").between(lit(1), lit(2));
745        assert!(!expr.nullable(&get_schema(false)).unwrap());
746        assert!(expr.nullable(&get_schema(true)).unwrap());
747
748        let null = lit(ScalarValue::Int32(None));
749
750        let expr = col("foo").between(null.clone(), lit(2));
751        assert!(expr.nullable(&get_schema(false)).unwrap());
752
753        let expr = col("foo").between(lit(1), null.clone());
754        assert!(expr.nullable(&get_schema(false)).unwrap());
755
756        let expr = col("foo").between(null.clone(), null);
757        assert!(expr.nullable(&get_schema(false)).unwrap());
758    }
759
760    fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) {
761        assert_eq!(
762            expr.nullable(schema).unwrap(),
763            expected,
764            "Nullability of '{expr}' should be {expected}"
765        );
766    }
767
768    fn assert_not_nullable(expr: &Expr, schema: &dyn ExprSchema) {
769        assert_nullability(expr, schema, false);
770    }
771
772    fn assert_nullable(expr: &Expr, schema: &dyn ExprSchema) {
773        assert_nullability(expr, schema, true);
774    }
775
776    #[test]
777    fn test_case_expression_nullability() -> Result<()> {
778        let nullable_schema = MockExprSchema::new()
779            .with_data_type(DataType::Int32)
780            .with_nullable(true);
781
782        let not_nullable_schema = MockExprSchema::new()
783            .with_data_type(DataType::Int32)
784            .with_nullable(false);
785
786        // CASE WHEN x IS NOT NULL THEN x ELSE 0
787        let e = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?;
788        assert_not_nullable(&e, &nullable_schema);
789        assert_not_nullable(&e, &not_nullable_schema);
790
791        // CASE WHEN NOT x IS NULL THEN x ELSE 0
792        let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?;
793        assert_not_nullable(&e, &nullable_schema);
794        assert_not_nullable(&e, &not_nullable_schema);
795
796        // CASE WHEN X = 5 THEN x ELSE 0
797        let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?;
798        assert_not_nullable(&e, &nullable_schema);
799        assert_not_nullable(&e, &not_nullable_schema);
800
801        // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0
802        let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x"))
803            .otherwise(lit(0))?;
804        assert_not_nullable(&e, &nullable_schema);
805        assert_not_nullable(&e, &not_nullable_schema);
806
807        // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0
808        let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x"))
809            .otherwise(lit(0))?;
810        assert_not_nullable(&e, &nullable_schema);
811        assert_not_nullable(&e, &not_nullable_schema);
812
813        // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0
814        let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x"))
815            .otherwise(lit(0))?;
816        assert_not_nullable(&e, &nullable_schema);
817        assert_not_nullable(&e, &not_nullable_schema);
818
819        // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0
820        let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x"))
821            .otherwise(lit(0))?;
822        assert_not_nullable(&e, &nullable_schema);
823        assert_not_nullable(&e, &not_nullable_schema);
824
825        // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0
826        let e = when(
827            or(
828                and(col("x").eq(lit(5)), col("x").is_not_null()),
829                and(col("x").eq(col("bar")), col("x").is_not_null()),
830            ),
831            col("x"),
832        )
833        .otherwise(lit(0))?;
834        assert_not_nullable(&e, &nullable_schema);
835        assert_not_nullable(&e, &not_nullable_schema);
836
837        // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0
838        let e = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x"))
839            .otherwise(lit(0))?;
840        assert_nullable(&e, &nullable_schema);
841        assert_not_nullable(&e, &not_nullable_schema);
842
843        // CASE WHEN x IS TRUE THEN x ELSE 0
844        let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?;
845        assert_not_nullable(&e, &nullable_schema);
846        assert_not_nullable(&e, &not_nullable_schema);
847
848        // CASE WHEN x IS NOT TRUE THEN x ELSE 0
849        let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?;
850        assert_nullable(&e, &nullable_schema);
851        assert_not_nullable(&e, &not_nullable_schema);
852
853        // CASE WHEN x IS FALSE THEN x ELSE 0
854        let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?;
855        assert_not_nullable(&e, &nullable_schema);
856        assert_not_nullable(&e, &not_nullable_schema);
857
858        // CASE WHEN x IS NOT FALSE THEN x ELSE 0
859        let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?;
860        assert_nullable(&e, &nullable_schema);
861        assert_not_nullable(&e, &not_nullable_schema);
862
863        // CASE WHEN x IS UNKNOWN THEN x ELSE 0
864        let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?;
865        assert_nullable(&e, &nullable_schema);
866        assert_not_nullable(&e, &not_nullable_schema);
867
868        // CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0
869        let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?;
870        assert_not_nullable(&e, &nullable_schema);
871        assert_not_nullable(&e, &not_nullable_schema);
872
873        // CASE WHEN x LIKE 'x' THEN x ELSE 0
874        let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?;
875        assert_not_nullable(&e, &nullable_schema);
876        assert_not_nullable(&e, &not_nullable_schema);
877
878        // CASE WHEN 0 THEN x ELSE 0
879        let e = when(lit(0), col("x")).otherwise(lit(0))?;
880        assert_not_nullable(&e, &nullable_schema);
881        assert_not_nullable(&e, &not_nullable_schema);
882
883        // CASE WHEN 1 THEN x ELSE 0
884        let e = when(lit(1), col("x")).otherwise(lit(0))?;
885        assert_nullable(&e, &nullable_schema);
886        assert_not_nullable(&e, &not_nullable_schema);
887
888        Ok(())
889    }
890
891    #[test]
892    fn test_inlist_nullability() {
893        let get_schema = |nullable| {
894            MockExprSchema::new()
895                .with_data_type(DataType::Int32)
896                .with_nullable(nullable)
897        };
898
899        let expr = col("foo").in_list(vec![lit(1); 5], false);
900        assert!(!expr.nullable(&get_schema(false)).unwrap());
901        assert!(expr.nullable(&get_schema(true)).unwrap());
902        // Testing nullable() returns an error.
903        assert!(
904            expr.nullable(&get_schema(false).with_error_on_nullable(true))
905                .is_err()
906        );
907
908        let null = lit(ScalarValue::Int32(None));
909        let expr = col("foo").in_list(vec![null, lit(1)], false);
910        assert!(expr.nullable(&get_schema(false)).unwrap());
911
912        // Testing on long list
913        let expr = col("foo").in_list(vec![lit(1); 6], false);
914        assert!(expr.nullable(&get_schema(false)).unwrap());
915    }
916
917    #[test]
918    fn test_like_nullability() {
919        let get_schema = |nullable| {
920            MockExprSchema::new()
921                .with_data_type(DataType::Utf8)
922                .with_nullable(nullable)
923        };
924
925        let expr = col("foo").like(lit("bar"));
926        assert!(!expr.nullable(&get_schema(false)).unwrap());
927        assert!(expr.nullable(&get_schema(true)).unwrap());
928
929        let expr = col("foo").like(lit(ScalarValue::Utf8(None)));
930        assert!(expr.nullable(&get_schema(false)).unwrap());
931    }
932
933    #[test]
934    fn expr_schema_data_type() {
935        let expr = col("foo");
936        assert_eq!(
937            DataType::Utf8,
938            expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
939                .unwrap()
940        );
941    }
942
943    #[test]
944    fn test_expr_metadata() {
945        let mut meta = HashMap::new();
946        meta.insert("bar".to_string(), "buzz".to_string());
947        let meta = FieldMetadata::from(meta);
948        let expr = col("foo");
949        let schema = MockExprSchema::new()
950            .with_data_type(DataType::Int32)
951            .with_metadata(meta.clone());
952
953        // col, alias, and cast should be metadata-preserving
954        assert_eq!(meta, expr.metadata(&schema).unwrap());
955        assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());
956        assert_eq!(
957            meta,
958            expr.clone()
959                .cast_to(&DataType::Int64, &schema)
960                .unwrap()
961                .metadata(&schema)
962                .unwrap()
963        );
964
965        let schema = DFSchema::from_unqualified_fields(
966            vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(),
967            HashMap::new(),
968        )
969        .unwrap();
970
971        // verify to_field method populates metadata
972        assert_eq!(meta, expr.metadata(&schema).unwrap());
973
974        // outer ref constructed by `out_ref_col_with_metadata` should be metadata-preserving
975        let outer_ref = out_ref_col_with_metadata(
976            DataType::Int32,
977            meta.to_hashmap(),
978            Column::from_name("foo"),
979        );
980        assert_eq!(meta, outer_ref.metadata(&schema).unwrap());
981    }
982
983    #[test]
984    fn test_expr_placeholder() {
985        let schema = MockExprSchema::new();
986
987        let mut placeholder_meta = HashMap::new();
988        placeholder_meta.insert("bar".to_string(), "buzz".to_string());
989        let placeholder_meta = FieldMetadata::from(placeholder_meta);
990
991        let expr = Expr::Placeholder(Placeholder::new_with_field(
992            "".to_string(),
993            Some(
994                Field::new("", DataType::Utf8, true)
995                    .with_metadata(placeholder_meta.to_hashmap())
996                    .into(),
997            ),
998        ));
999
1000        let field = expr.to_field(&schema).unwrap().1;
1001        assert_eq!(
1002            (field.data_type(), field.is_nullable()),
1003            (&DataType::Utf8, true)
1004        );
1005        assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap());
1006
1007        let expr_alias = expr.alias("a placeholder by any other name");
1008        let expr_alias_field = expr_alias.to_field(&schema).unwrap().1;
1009        assert_eq!(
1010            (expr_alias_field.data_type(), expr_alias_field.is_nullable()),
1011            (&DataType::Utf8, true)
1012        );
1013        assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap());
1014
1015        // Non-nullable placeholder field should remain non-nullable
1016        let expr = Expr::Placeholder(Placeholder::new_with_field(
1017            "".to_string(),
1018            Some(Field::new("", DataType::Utf8, false).into()),
1019        ));
1020        let expr_field = expr.to_field(&schema).unwrap().1;
1021        assert_eq!(
1022            (expr_field.data_type(), expr_field.is_nullable()),
1023            (&DataType::Utf8, false)
1024        );
1025
1026        let expr_alias = expr.alias("a placeholder by any other name");
1027        let expr_alias_field = expr_alias.to_field(&schema).unwrap().1;
1028        assert_eq!(
1029            (expr_alias_field.data_type(), expr_alias_field.is_nullable()),
1030            (&DataType::Utf8, false)
1031        );
1032    }
1033
    /// Minimal `ExprSchema` test double: every column lookup is answered with the same
    /// single field, whatever column is asked for.
    #[derive(Debug)]
    struct MockExprSchema {
        /// The one field returned for any column.
        field: FieldRef,
        /// When `true`, `nullable()` returns an error instead of the field's nullability.
        error_on_nullable: bool,
    }
1039
1040    impl MockExprSchema {
1041        fn new() -> Self {
1042            Self {
1043                field: Arc::new(Field::new("mock_field", DataType::Null, false)),
1044                error_on_nullable: false,
1045            }
1046        }
1047
1048        fn with_nullable(mut self, nullable: bool) -> Self {
1049            Arc::make_mut(&mut self.field).set_nullable(nullable);
1050            self
1051        }
1052
1053        fn with_data_type(mut self, data_type: DataType) -> Self {
1054            Arc::make_mut(&mut self.field).set_data_type(data_type);
1055            self
1056        }
1057
1058        fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self {
1059            self.error_on_nullable = error_on_nullable;
1060            self
1061        }
1062
1063        fn with_metadata(mut self, metadata: FieldMetadata) -> Self {
1064            self.field =
1065                Arc::new(metadata.add_to_field(Arc::unwrap_or_clone(self.field)));
1066            self
1067        }
1068    }
1069
1070    impl ExprSchema for MockExprSchema {
1071        fn nullable(&self, _col: &Column) -> Result<bool> {
1072            assert_or_internal_err!(!self.error_on_nullable, "nullable error");
1073            Ok(self.field.is_nullable())
1074        }
1075
1076        fn field_from_column(&self, _col: &Column) -> Result<&FieldRef> {
1077            Ok(&self.field)
1078        }
1079    }
1080
1081    #[test]
1082    fn test_scalar_variable() {
1083        let mut meta = HashMap::new();
1084        meta.insert("bar".to_string(), "buzz".to_string());
1085        let meta = FieldMetadata::from(meta);
1086
1087        let field = Field::new("foo", DataType::Int32, true);
1088        let field = meta.add_to_field(field);
1089        let field = Arc::new(field);
1090
1091        let expr = Expr::ScalarVariable(field, vec!["foo".to_string()]);
1092
1093        let schema = MockExprSchema::new();
1094
1095        assert_eq!(meta, expr.metadata(&schema).unwrap());
1096    }
1097}