datafusion_expr/
expr_schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{Between, Expr, Like};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25    data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf,
26};
27use crate::udf::ReturnFieldArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field, FieldRef};
31use datafusion_common::{
32    not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
33    Result, Spans, TableReference,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::collections::HashMap;
38use std::sync::Arc;
39
40/// Trait to allow expr to typable with respect to a schema
41pub trait ExprSchemable {
42    /// Given a schema, return the type of the expr
43    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
44
45    /// Given a schema, return the nullability of the expr
46    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
47
48    /// Given a schema, return the expr's optional metadata
49    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;
50
51    /// Convert to a field with respect to a schema
52    fn to_field(
53        &self,
54        input_schema: &dyn ExprSchema,
55    ) -> Result<(Option<TableReference>, Arc<Field>)>;
56
57    /// Cast to a type with respect to a schema
58    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
59
60    /// Given a schema, return the type and nullability of the expr
61    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
62        -> Result<(DataType, bool)>;
63}
64
65impl ExprSchemable for Expr {
66    /// Returns the [arrow::datatypes::DataType] of the expression
67    /// based on [ExprSchema]
68    ///
69    /// Note: [`DFSchema`] implements [ExprSchema].
70    ///
71    /// [`DFSchema`]: datafusion_common::DFSchema
72    ///
73    /// # Examples
74    ///
75    /// Get the type of an expression that adds 2 columns. Adding an Int32
76    /// and Float32 results in Float32 type
77    ///
78    /// ```
79    /// # use arrow::datatypes::{DataType, Field};
80    /// # use datafusion_common::DFSchema;
81    /// # use datafusion_expr::{col, ExprSchemable};
82    /// # use std::collections::HashMap;
83    ///
84    /// fn main() {
85    ///   let expr = col("c1") + col("c2");
86    ///   let schema = DFSchema::from_unqualified_fields(
87    ///     vec![
88    ///       Field::new("c1", DataType::Int32, true),
89    ///       Field::new("c2", DataType::Float32, true),
90    ///       ].into(),
91    ///       HashMap::new(),
92    ///   ).unwrap();
93    ///   assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
94    /// }
95    /// ```
96    ///
97    /// # Errors
98    ///
99    /// This function errors when it is not possible to compute its
100    /// [arrow::datatypes::DataType].  This happens when e.g. the
101    /// expression refers to a column that does not exist in the
102    /// schema, or when the expression is incorrectly typed
103    /// (e.g. `[utf8] + [bool]`).
104    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
105    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
106        match self {
107            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
108                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
109                    None => schema.data_type(&Column::from_name(name)).cloned(),
110                    Some(dt) => Ok(dt.clone()),
111                },
112                _ => expr.get_type(schema),
113            },
114            Expr::Negative(expr) => expr.get_type(schema),
115            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
116            Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
117            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
118            Expr::Literal(l, _) => Ok(l.data_type()),
119            Expr::Case(case) => {
120                for (_, then_expr) in &case.when_then_expr {
121                    let then_type = then_expr.get_type(schema)?;
122                    if !then_type.is_null() {
123                        return Ok(then_type);
124                    }
125                }
126                case.else_expr
127                    .as_ref()
128                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
129            }
130            Expr::Cast(Cast { data_type, .. })
131            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
132            Expr::Unnest(Unnest { expr }) => {
133                let arg_data_type = expr.get_type(schema)?;
134                // Unnest's output type is the inner type of the list
135                match arg_data_type {
136                    DataType::List(field)
137                    | DataType::LargeList(field)
138                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
139                    DataType::Struct(_) => Ok(arg_data_type),
140                    DataType::Null => {
141                        not_impl_err!("unnest() does not support null yet")
142                    }
143                    _ => {
144                        plan_err!(
145                            "unnest() can only be applied to array, struct and null"
146                        )
147                    }
148                }
149            }
150            Expr::ScalarFunction(_func) => {
151                let (return_type, _) = self.data_type_and_nullable(schema)?;
152                Ok(return_type)
153            }
154            Expr::WindowFunction(window_function) => self
155                .data_type_and_nullable_with_window_function(schema, window_function)
156                .map(|(return_type, _)| return_type),
157            Expr::AggregateFunction(AggregateFunction {
158                func,
159                params: AggregateFunctionParams { args, .. },
160            }) => {
161                let fields = args
162                    .iter()
163                    .map(|e| e.to_field(schema).map(|(_, f)| f))
164                    .collect::<Result<Vec<_>>>()?;
165                let new_fields = fields_with_aggregate_udf(&fields, func)
166                    .map_err(|err| {
167                        let data_types = fields
168                            .iter()
169                            .map(|f| f.data_type().clone())
170                            .collect::<Vec<_>>();
171                        plan_datafusion_err!(
172                            "{} {}",
173                            match err {
174                                DataFusionError::Plan(msg) => msg,
175                                err => err.to_string(),
176                            },
177                            utils::generate_signature_error_msg(
178                                func.name(),
179                                func.signature().clone(),
180                                &data_types
181                            )
182                        )
183                    })?
184                    .into_iter()
185                    .collect::<Vec<_>>();
186                Ok(func.return_field(&new_fields)?.data_type().clone())
187            }
188            Expr::Not(_)
189            | Expr::IsNull(_)
190            | Expr::Exists { .. }
191            | Expr::InSubquery(_)
192            | Expr::Between { .. }
193            | Expr::InList { .. }
194            | Expr::IsNotNull(_)
195            | Expr::IsTrue(_)
196            | Expr::IsFalse(_)
197            | Expr::IsUnknown(_)
198            | Expr::IsNotTrue(_)
199            | Expr::IsNotFalse(_)
200            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
201            Expr::ScalarSubquery(subquery) => {
202                Ok(subquery.subquery.schema().field(0).data_type().clone())
203            }
204            Expr::BinaryExpr(BinaryExpr {
205                ref left,
206                ref right,
207                ref op,
208            }) => BinaryTypeCoercer::new(
209                &left.get_type(schema)?,
210                op,
211                &right.get_type(schema)?,
212            )
213            .get_result_type(),
214            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
215            Expr::Placeholder(Placeholder { data_type, .. }) => {
216                if let Some(dtype) = data_type {
217                    Ok(dtype.clone())
218                } else {
219                    // If the placeholder's type hasn't been specified, treat it as
220                    // null (unspecified placeholders generate an error during planning)
221                    Ok(DataType::Null)
222                }
223            }
224            #[expect(deprecated)]
225            Expr::Wildcard { .. } => Ok(DataType::Null),
226            Expr::GroupingSet(_) => {
227                // Grouping sets do not really have a type and do not appear in projections
228                Ok(DataType::Null)
229            }
230        }
231    }
232
233    /// Returns the nullability of the expression based on [ExprSchema].
234    ///
235    /// Note: [`DFSchema`] implements [ExprSchema].
236    ///
237    /// [`DFSchema`]: datafusion_common::DFSchema
238    ///
239    /// # Errors
240    ///
241    /// This function errors when it is not possible to compute its
242    /// nullability.  This happens when the expression refers to a
243    /// column that does not exist in the schema.
244    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
245        match self {
246            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
247                expr.nullable(input_schema)
248            }
249
250            Expr::InList(InList { expr, list, .. }) => {
251                // Avoid inspecting too many expressions.
252                const MAX_INSPECT_LIMIT: usize = 6;
253                // Stop if a nullable expression is found or an error occurs.
254                let has_nullable = std::iter::once(expr.as_ref())
255                    .chain(list)
256                    .take(MAX_INSPECT_LIMIT)
257                    .find_map(|e| {
258                        e.nullable(input_schema)
259                            .map(|nullable| if nullable { Some(()) } else { None })
260                            .transpose()
261                    })
262                    .transpose()?;
263                Ok(match has_nullable {
264                    // If a nullable subexpression is found, the result may also be nullable.
265                    Some(_) => true,
266                    // If the list is too long, we assume it is nullable.
267                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
268                    // All the subexpressions are non-nullable, so the result must be non-nullable.
269                    _ => false,
270                })
271            }
272
273            Expr::Between(Between {
274                expr, low, high, ..
275            }) => Ok(expr.nullable(input_schema)?
276                || low.nullable(input_schema)?
277                || high.nullable(input_schema)?),
278
279            Expr::Column(c) => input_schema.nullable(c),
280            Expr::OuterReferenceColumn(_, _) => Ok(true),
281            Expr::Literal(value, _) => Ok(value.is_null()),
282            Expr::Case(case) => {
283                // This expression is nullable if any of the input expressions are nullable
284                let then_nullable = case
285                    .when_then_expr
286                    .iter()
287                    .map(|(_, t)| t.nullable(input_schema))
288                    .collect::<Result<Vec<_>>>()?;
289                if then_nullable.contains(&true) {
290                    Ok(true)
291                } else if let Some(e) = &case.else_expr {
292                    e.nullable(input_schema)
293                } else {
294                    // CASE produces NULL if there is no `else` expr
295                    // (aka when none of the `when_then_exprs` match)
296                    Ok(true)
297                }
298            }
299            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
300            Expr::ScalarFunction(_func) => {
301                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
302                Ok(nullable)
303            }
304            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
305                Ok(func.is_nullable())
306            }
307            Expr::WindowFunction(window_function) => self
308                .data_type_and_nullable_with_window_function(
309                    input_schema,
310                    window_function,
311                )
312                .map(|(_, nullable)| nullable),
313            Expr::ScalarVariable(_, _)
314            | Expr::TryCast { .. }
315            | Expr::Unnest(_)
316            | Expr::Placeholder(_) => Ok(true),
317            Expr::IsNull(_)
318            | Expr::IsNotNull(_)
319            | Expr::IsTrue(_)
320            | Expr::IsFalse(_)
321            | Expr::IsUnknown(_)
322            | Expr::IsNotTrue(_)
323            | Expr::IsNotFalse(_)
324            | Expr::IsNotUnknown(_)
325            | Expr::Exists { .. } => Ok(false),
326            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
327            Expr::ScalarSubquery(subquery) => {
328                Ok(subquery.subquery.schema().field(0).is_nullable())
329            }
330            Expr::BinaryExpr(BinaryExpr {
331                ref left,
332                ref right,
333                ..
334            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
335            Expr::Like(Like { expr, pattern, .. })
336            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
337                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
338            }
339            #[expect(deprecated)]
340            Expr::Wildcard { .. } => Ok(false),
341            Expr::GroupingSet(_) => {
342                // Grouping sets do not really have the concept of nullable and do not appear
343                // in projections
344                Ok(true)
345            }
346        }
347    }
348
349    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
350        self.to_field(schema)
351            .map(|(_, field)| field.metadata().clone())
352    }
353
354    /// Returns the datatype and nullability of the expression based on [ExprSchema].
355    ///
356    /// Note: [`DFSchema`] implements [ExprSchema].
357    ///
358    /// [`DFSchema`]: datafusion_common::DFSchema
359    ///
360    /// # Errors
361    ///
362    /// This function errors when it is not possible to compute its
363    /// datatype or nullability.
364    fn data_type_and_nullable(
365        &self,
366        schema: &dyn ExprSchema,
367    ) -> Result<(DataType, bool)> {
368        let field = self.to_field(schema)?.1;
369
370        Ok((field.data_type().clone(), field.is_nullable()))
371    }
372
373    /// Returns a [arrow::datatypes::Field] compatible with this expression.
374    ///
375    /// So for example, a projected expression `col(c1) + col(c2)` is
376    /// placed in an output field **named** col("c1 + c2")
377    fn to_field(
378        &self,
379        schema: &dyn ExprSchema,
380    ) -> Result<(Option<TableReference>, Arc<Field>)> {
381        let (relation, schema_name) = self.qualified_name();
382        #[allow(deprecated)]
383        let field = match self {
384            Expr::Alias(Alias {
385                expr,
386                name,
387                metadata,
388                ..
389            }) => {
390                let field = match &**expr {
391                    Expr::Placeholder(Placeholder { data_type, .. }) => {
392                        match &data_type {
393                            None => schema
394                                .data_type_and_nullable(&Column::from_name(name))
395                                .map(|(d, n)| Field::new(&schema_name, d.clone(), n)),
396                            Some(dt) => Ok(Field::new(
397                                &schema_name,
398                                dt.clone(),
399                                expr.nullable(schema)?,
400                            )),
401                        }
402                    }
403                    _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()),
404                }?;
405
406                let mut combined_metadata = expr.metadata(schema)?;
407                if let Some(metadata) = metadata {
408                    if !metadata.is_empty() {
409                        combined_metadata.extend(metadata.clone());
410                    }
411                }
412
413                Ok(Arc::new(field.with_metadata(combined_metadata)))
414            }
415            Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f),
416            Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())),
417            Expr::OuterReferenceColumn(ty, _) => {
418                Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
419            }
420            Expr::ScalarVariable(ty, _) => {
421                Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
422            }
423            Expr::Literal(l, metadata) => {
424                let mut field = Field::new(&schema_name, l.data_type(), l.is_null());
425                if let Some(metadata) = metadata {
426                    field = field.with_metadata(
427                        metadata
428                            .iter()
429                            .map(|(k, v)| (k.clone(), v.clone()))
430                            .collect(),
431                    );
432                }
433                Ok(Arc::new(field))
434            }
435            Expr::IsNull(_)
436            | Expr::IsNotNull(_)
437            | Expr::IsTrue(_)
438            | Expr::IsFalse(_)
439            | Expr::IsUnknown(_)
440            | Expr::IsNotTrue(_)
441            | Expr::IsNotFalse(_)
442            | Expr::IsNotUnknown(_)
443            | Expr::Exists { .. } => {
444                Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false)))
445            }
446            Expr::ScalarSubquery(subquery) => {
447                Ok(Arc::new(subquery.subquery.schema().field(0).clone()))
448            }
449            Expr::BinaryExpr(BinaryExpr {
450                ref left,
451                ref right,
452                ref op,
453            }) => {
454                let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
455                let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
456                let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
457                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
458                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
459                Ok(Arc::new(Field::new(
460                    &schema_name,
461                    coercer.get_result_type()?,
462                    lhs_nullable || rhs_nullable,
463                )))
464            }
465            Expr::WindowFunction(window_function) => {
466                let (dt, nullable) = self.data_type_and_nullable_with_window_function(
467                    schema,
468                    window_function,
469                )?;
470                Ok(Arc::new(Field::new(&schema_name, dt, nullable)))
471            }
472            Expr::AggregateFunction(aggregate_function) => {
473                let AggregateFunction {
474                    func,
475                    params: AggregateFunctionParams { args, .. },
476                    ..
477                } = aggregate_function;
478
479                let fields = args
480                    .iter()
481                    .map(|e| e.to_field(schema).map(|(_, f)| f))
482                    .collect::<Result<Vec<_>>>()?;
483                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
484                let new_fields = fields_with_aggregate_udf(&fields, func)
485                    .map_err(|err| {
486                        let arg_types = fields
487                            .iter()
488                            .map(|f| f.data_type())
489                            .cloned()
490                            .collect::<Vec<_>>();
491                        plan_datafusion_err!(
492                            "{} {}",
493                            match err {
494                                DataFusionError::Plan(msg) => msg,
495                                err => err.to_string(),
496                            },
497                            utils::generate_signature_error_msg(
498                                func.name(),
499                                func.signature().clone(),
500                                &arg_types,
501                            )
502                        )
503                    })?
504                    .into_iter()
505                    .collect::<Vec<_>>();
506
507                func.return_field(&new_fields)
508            }
509            Expr::ScalarFunction(ScalarFunction { func, args }) => {
510                let (arg_types, fields): (Vec<DataType>, Vec<Arc<Field>>) = args
511                    .iter()
512                    .map(|e| e.to_field(schema).map(|(_, f)| f))
513                    .collect::<Result<Vec<_>>>()?
514                    .into_iter()
515                    .map(|f| (f.data_type().clone(), f))
516                    .unzip();
517                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
518                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
519                    .map_err(|err| {
520                        plan_datafusion_err!(
521                            "{} {}",
522                            match err {
523                                DataFusionError::Plan(msg) => msg,
524                                err => err.to_string(),
525                            },
526                            utils::generate_signature_error_msg(
527                                func.name(),
528                                func.signature().clone(),
529                                &arg_types,
530                            )
531                        )
532                    })?;
533                let new_fields = fields
534                    .into_iter()
535                    .zip(new_data_types)
536                    .map(|(f, d)| f.as_ref().clone().with_data_type(d))
537                    .map(Arc::new)
538                    .collect::<Vec<FieldRef>>();
539
540                let arguments = args
541                    .iter()
542                    .map(|e| match e {
543                        Expr::Literal(sv, _) => Some(sv),
544                        _ => None,
545                    })
546                    .collect::<Vec<_>>();
547                let args = ReturnFieldArgs {
548                    arg_fields: &new_fields,
549                    scalar_arguments: &arguments,
550                };
551
552                func.return_field_from_args(args)
553            }
554            // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
555            Expr::Cast(Cast { expr, data_type }) => expr
556                .to_field(schema)
557                .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone()))
558                .map(Arc::new),
559            Expr::Like(_)
560            | Expr::SimilarTo(_)
561            | Expr::Not(_)
562            | Expr::Between(_)
563            | Expr::Case(_)
564            | Expr::TryCast(_)
565            | Expr::InList(_)
566            | Expr::InSubquery(_)
567            | Expr::Wildcard { .. }
568            | Expr::GroupingSet(_)
569            | Expr::Placeholder(_)
570            | Expr::Unnest(_) => Ok(Arc::new(Field::new(
571                &schema_name,
572                self.get_type(schema)?,
573                self.nullable(schema)?,
574            ))),
575        }?;
576
577        Ok((
578            relation,
579            Arc::new(field.as_ref().clone().with_name(schema_name)),
580        ))
581    }
582
583    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
584    ///
585    /// # Errors
586    ///
587    /// This function errors when it is impossible to cast the
588    /// expression to the target [arrow::datatypes::DataType].
589    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
590        let this_type = self.get_type(schema)?;
591        if this_type == *cast_to_type {
592            return Ok(self);
593        }
594
595        // TODO(kszucs): Most of the operations do not validate the type correctness
596        // like all of the binary expressions below. Perhaps Expr should track the
597        // type of the expression?
598
599        if can_cast_types(&this_type, cast_to_type) {
600            match self {
601                Expr::ScalarSubquery(subquery) => {
602                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
603                }
604                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
605            }
606        } else {
607            plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
608        }
609    }
610}
611
612impl Expr {
613    /// Common method for window functions that applies type coercion
614    /// to all arguments of the window function to check if it matches
615    /// its signature.
616    ///
617    /// If successful, this method returns the data type and
618    /// nullability of the window function's result.
619    ///
620    /// Otherwise, returns an error if there's a type mismatch between
621    /// the window function's signature and the provided arguments.
622    fn data_type_and_nullable_with_window_function(
623        &self,
624        schema: &dyn ExprSchema,
625        window_function: &WindowFunction,
626    ) -> Result<(DataType, bool)> {
627        let WindowFunction {
628            fun,
629            params: WindowFunctionParams { args, .. },
630            ..
631        } = window_function;
632
633        let fields = args
634            .iter()
635            .map(|e| e.to_field(schema).map(|(_, f)| f))
636            .collect::<Result<Vec<_>>>()?;
637        match fun {
638            WindowFunctionDefinition::AggregateUDF(udaf) => {
639                let data_types = fields
640                    .iter()
641                    .map(|f| f.data_type())
642                    .cloned()
643                    .collect::<Vec<_>>();
644                let new_fields = fields_with_aggregate_udf(&fields, udaf)
645                    .map_err(|err| {
646                        plan_datafusion_err!(
647                            "{} {}",
648                            match err {
649                                DataFusionError::Plan(msg) => msg,
650                                err => err.to_string(),
651                            },
652                            utils::generate_signature_error_msg(
653                                fun.name(),
654                                fun.signature(),
655                                &data_types
656                            )
657                        )
658                    })?
659                    .into_iter()
660                    .collect::<Vec<_>>();
661
662                let return_field = udaf.return_field(&new_fields)?;
663
664                Ok((return_field.data_type().clone(), return_field.is_nullable()))
665            }
666            WindowFunctionDefinition::WindowUDF(udwf) => {
667                let data_types = fields
668                    .iter()
669                    .map(|f| f.data_type())
670                    .cloned()
671                    .collect::<Vec<_>>();
672                let new_fields = fields_with_window_udf(&fields, udwf)
673                    .map_err(|err| {
674                        plan_datafusion_err!(
675                            "{} {}",
676                            match err {
677                                DataFusionError::Plan(msg) => msg,
678                                err => err.to_string(),
679                            },
680                            utils::generate_signature_error_msg(
681                                fun.name(),
682                                fun.signature(),
683                                &data_types
684                            )
685                        )
686                    })?
687                    .into_iter()
688                    .collect::<Vec<_>>();
689                let (_, function_name) = self.qualified_name();
690                let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name);
691
692                udwf.field(field_args)
693                    .map(|field| (field.data_type().clone(), field.is_nullable()))
694            }
695        }
696    }
697}
698
699/// Cast subquery in InSubquery/ScalarSubquery to a given type.
700///
701/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
702///    columns), it casts the first expression in the projection to the target type and creates a
703///    new projection with the casted expression.
704/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
705///    with the casted first column.
706///
707pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
708    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
709        return Ok(subquery);
710    }
711
712    let plan = subquery.subquery.as_ref();
713    let new_plan = match plan {
714        LogicalPlan::Projection(projection) => {
715            let cast_expr = projection.expr[0]
716                .clone()
717                .cast_to(cast_to_type, projection.input.schema())?;
718            LogicalPlan::Projection(Projection::try_new(
719                vec![cast_expr],
720                Arc::clone(&projection.input),
721            )?)
722        }
723        _ => {
724            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
725                .cast_to(cast_to_type, subquery.subquery.schema())?;
726            LogicalPlan::Projection(Projection::try_new(
727                vec![cast_expr],
728                subquery.subquery,
729            )?)
730        }
731    };
732    Ok(Subquery {
733        subquery: Arc::new(new_plan),
734        outer_ref_columns: subquery.outer_ref_columns,
735        spans: Spans::new(),
736    })
737}
738
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Mock schema whose single column has the given type and nullability.
    fn mock_schema(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let eq_expr = col("foo").eq(lit(1));
        assert!(!eq_expr.nullable(&MockExprSchema::new()).unwrap());
        assert!(eq_expr
            .nullable(&MockExprSchema::new().with_nullable(true))
            .unwrap());

        // Each IS [NOT] NULL/TRUE/FALSE/UNKNOWN predicate is expected to be
        // non-nullable, even when applied to a NULL literal.
        let predicates: [fn(Expr) -> Expr; 8] = [
            Expr::is_null,
            Expr::is_not_null,
            Expr::is_true,
            Expr::is_not_true,
            Expr::is_false,
            Expr::is_not_false,
            Expr::is_unknown,
            Expr::is_not_unknown,
        ];
        for make_predicate in predicates {
            let predicate = make_predicate(lit(ScalarValue::Null));
            assert!(!predicate.nullable(&MockExprSchema::new()).unwrap());
        }
    }

    #[test]
    fn test_between_nullability() {
        let bounded = col("foo").between(lit(1), lit(2));
        assert!(!bounded
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());
        assert!(bounded
            .nullable(&mock_schema(DataType::Int32, true))
            .unwrap());

        // A NULL on either bound makes the result nullable even for a
        // non-nullable column.
        let null_int = lit(ScalarValue::Int32(None));
        let null_bounds = [
            (null_int.clone(), lit(2)),
            (lit(1), null_int.clone()),
            (null_int.clone(), null_int),
        ];
        for (low, high) in null_bounds {
            let expr = col("foo").between(low, high);
            assert!(expr.nullable(&mock_schema(DataType::Int32, false)).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let short_list = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!short_list
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());
        assert!(short_list
            .nullable(&mock_schema(DataType::Int32, true))
            .unwrap());
        // An error raised by the schema's nullable() lookup must surface.
        let failing_schema =
            mock_schema(DataType::Int32, false).with_error_on_nullable(true);
        assert!(short_list.nullable(&failing_schema).is_err());

        let with_null =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());

        // A six-element list of non-null literals is still reported nullable.
        let long_list = col("foo").in_list(vec![lit(1); 6], false);
        assert!(long_list
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let like_bar = col("foo").like(lit("bar"));
        assert!(!like_bar
            .nullable(&mock_schema(DataType::Utf8, false))
            .unwrap());
        assert!(like_bar
            .nullable(&mock_schema(DataType::Utf8, true))
            .unwrap());

        let like_null = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(like_null
            .nullable(&mock_schema(DataType::Utf8, false))
            .unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let utf8_schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let resolved = col("foo").get_type(&utf8_schema).unwrap();
        assert_eq!(DataType::Utf8, resolved);
    }

    #[test]
    fn test_expr_metadata() {
        let meta: HashMap<String, String> =
            HashMap::from([("bar".to_string(), "buzz".to_string())]);
        let foo = col("foo");
        let mock = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // Column references, aliases and casts all keep the field metadata.
        assert_eq!(meta, foo.metadata(&mock).unwrap());
        assert_eq!(meta, foo.clone().alias("bar").metadata(&mock).unwrap());
        let casted = foo.clone().cast_to(&DataType::Int64, &mock).unwrap();
        assert_eq!(meta, casted.metadata(&mock).unwrap());

        let foo_field =
            Field::new("foo", DataType::Int32, true).with_metadata(meta.clone());
        let df_schema =
            DFSchema::from_unqualified_fields(vec![foo_field].into(), HashMap::new())
                .unwrap();

        // to_field must attach the column's metadata to the returned field.
        let (_, field) = foo.to_field(&df_schema).unwrap();
        assert_eq!(&meta, field.metadata());
    }

    /// [`ExprSchema`] stand-in that reports the same `field` for every column.
    #[derive(Debug)]
    struct MockExprSchema {
        field: Field,
        /// When true, `nullable` returns an internal error instead of a value.
        error_on_nullable: bool,
    }

    impl MockExprSchema {
        fn new() -> Self {
            let field = Field::new("mock_field", DataType::Null, false);
            Self {
                field,
                error_on_nullable: false,
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self {
                field: self.field.with_nullable(nullable),
                ..self
            }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self {
                field: self.field.with_data_type(data_type),
                ..self
            }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self {
                field: self.field.with_metadata(metadata),
                ..self
            }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if !self.error_on_nullable {
                return Ok(self.field.is_nullable());
            }
            internal_err!("nullable error")
        }

        fn field_from_column(&self, _col: &Column) -> Result<&Field> {
            Ok(&self.field)
        }
    }
}