datafusion_expr/
expr_schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{Between, Expr, Like};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata,
21    InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25    data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf,
26};
27use crate::udf::ReturnFieldArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field, FieldRef};
31use datafusion_common::{
32    not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
33    Result, Spans, TableReference,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::sync::Arc;
38
39/// Trait to allow expr to typable with respect to a schema
40pub trait ExprSchemable {
41    /// Given a schema, return the type of the expr
42    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
43
44    /// Given a schema, return the nullability of the expr
45    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
46
47    /// Given a schema, return the expr's optional metadata
48    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata>;
49
50    /// Convert to a field with respect to a schema
51    fn to_field(
52        &self,
53        input_schema: &dyn ExprSchema,
54    ) -> Result<(Option<TableReference>, Arc<Field>)>;
55
56    /// Cast to a type with respect to a schema
57    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
58
59    /// Given a schema, return the type and nullability of the expr
60    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
61        -> Result<(DataType, bool)>;
62}
63
64impl ExprSchemable for Expr {
65    /// Returns the [arrow::datatypes::DataType] of the expression
66    /// based on [ExprSchema]
67    ///
68    /// Note: [`DFSchema`] implements [ExprSchema].
69    ///
70    /// [`DFSchema`]: datafusion_common::DFSchema
71    ///
72    /// # Examples
73    ///
74    /// Get the type of an expression that adds 2 columns. Adding an Int32
75    /// and Float32 results in Float32 type
76    ///
77    /// ```
78    /// # use arrow::datatypes::{DataType, Field};
79    /// # use datafusion_common::DFSchema;
80    /// # use datafusion_expr::{col, ExprSchemable};
81    /// # use std::collections::HashMap;
82    ///
83    /// fn main() {
84    ///   let expr = col("c1") + col("c2");
85    ///   let schema = DFSchema::from_unqualified_fields(
86    ///     vec![
87    ///       Field::new("c1", DataType::Int32, true),
88    ///       Field::new("c2", DataType::Float32, true),
89    ///       ].into(),
90    ///       HashMap::new(),
91    ///   ).unwrap();
92    ///   assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
93    /// }
94    /// ```
95    ///
96    /// # Errors
97    ///
98    /// This function errors when it is not possible to compute its
99    /// [arrow::datatypes::DataType].  This happens when e.g. the
100    /// expression refers to a column that does not exist in the
101    /// schema, or when the expression is incorrectly typed
102    /// (e.g. `[utf8] + [bool]`).
103    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
104    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
105        match self {
106            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
107                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
108                    None => schema.data_type(&Column::from_name(name)).cloned(),
109                    Some(dt) => Ok(dt.clone()),
110                },
111                _ => expr.get_type(schema),
112            },
113            Expr::Negative(expr) => expr.get_type(schema),
114            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
115            Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
116            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
117            Expr::Literal(l, _) => Ok(l.data_type()),
118            Expr::Case(case) => {
119                for (_, then_expr) in &case.when_then_expr {
120                    let then_type = then_expr.get_type(schema)?;
121                    if !then_type.is_null() {
122                        return Ok(then_type);
123                    }
124                }
125                case.else_expr
126                    .as_ref()
127                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
128            }
129            Expr::Cast(Cast { data_type, .. })
130            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
131            Expr::Unnest(Unnest { expr }) => {
132                let arg_data_type = expr.get_type(schema)?;
133                // Unnest's output type is the inner type of the list
134                match arg_data_type {
135                    DataType::List(field)
136                    | DataType::LargeList(field)
137                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
138                    DataType::Struct(_) => Ok(arg_data_type),
139                    DataType::Null => {
140                        not_impl_err!("unnest() does not support null yet")
141                    }
142                    _ => {
143                        plan_err!(
144                            "unnest() can only be applied to array, struct and null"
145                        )
146                    }
147                }
148            }
149            Expr::ScalarFunction(_func) => {
150                let (return_type, _) = self.data_type_and_nullable(schema)?;
151                Ok(return_type)
152            }
153            Expr::WindowFunction(window_function) => self
154                .data_type_and_nullable_with_window_function(schema, window_function)
155                .map(|(return_type, _)| return_type),
156            Expr::AggregateFunction(AggregateFunction {
157                func,
158                params: AggregateFunctionParams { args, .. },
159            }) => {
160                let fields = args
161                    .iter()
162                    .map(|e| e.to_field(schema).map(|(_, f)| f))
163                    .collect::<Result<Vec<_>>>()?;
164                let new_fields = fields_with_aggregate_udf(&fields, func)
165                    .map_err(|err| {
166                        let data_types = fields
167                            .iter()
168                            .map(|f| f.data_type().clone())
169                            .collect::<Vec<_>>();
170                        plan_datafusion_err!(
171                            "{} {}",
172                            match err {
173                                DataFusionError::Plan(msg) => msg,
174                                err => err.to_string(),
175                            },
176                            utils::generate_signature_error_msg(
177                                func.name(),
178                                func.signature().clone(),
179                                &data_types
180                            )
181                        )
182                    })?
183                    .into_iter()
184                    .collect::<Vec<_>>();
185                Ok(func.return_field(&new_fields)?.data_type().clone())
186            }
187            Expr::Not(_)
188            | Expr::IsNull(_)
189            | Expr::Exists { .. }
190            | Expr::InSubquery(_)
191            | Expr::Between { .. }
192            | Expr::InList { .. }
193            | Expr::IsNotNull(_)
194            | Expr::IsTrue(_)
195            | Expr::IsFalse(_)
196            | Expr::IsUnknown(_)
197            | Expr::IsNotTrue(_)
198            | Expr::IsNotFalse(_)
199            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
200            Expr::ScalarSubquery(subquery) => {
201                Ok(subquery.subquery.schema().field(0).data_type().clone())
202            }
203            Expr::BinaryExpr(BinaryExpr {
204                ref left,
205                ref right,
206                ref op,
207            }) => BinaryTypeCoercer::new(
208                &left.get_type(schema)?,
209                op,
210                &right.get_type(schema)?,
211            )
212            .get_result_type(),
213            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
214            Expr::Placeholder(Placeholder { data_type, .. }) => {
215                if let Some(dtype) = data_type {
216                    Ok(dtype.clone())
217                } else {
218                    // If the placeholder's type hasn't been specified, treat it as
219                    // null (unspecified placeholders generate an error during planning)
220                    Ok(DataType::Null)
221                }
222            }
223            #[expect(deprecated)]
224            Expr::Wildcard { .. } => Ok(DataType::Null),
225            Expr::GroupingSet(_) => {
226                // Grouping sets do not really have a type and do not appear in projections
227                Ok(DataType::Null)
228            }
229        }
230    }
231
232    /// Returns the nullability of the expression based on [ExprSchema].
233    ///
234    /// Note: [`DFSchema`] implements [ExprSchema].
235    ///
236    /// [`DFSchema`]: datafusion_common::DFSchema
237    ///
238    /// # Errors
239    ///
240    /// This function errors when it is not possible to compute its
241    /// nullability.  This happens when the expression refers to a
242    /// column that does not exist in the schema.
243    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
244        match self {
245            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
246                expr.nullable(input_schema)
247            }
248
249            Expr::InList(InList { expr, list, .. }) => {
250                // Avoid inspecting too many expressions.
251                const MAX_INSPECT_LIMIT: usize = 6;
252                // Stop if a nullable expression is found or an error occurs.
253                let has_nullable = std::iter::once(expr.as_ref())
254                    .chain(list)
255                    .take(MAX_INSPECT_LIMIT)
256                    .find_map(|e| {
257                        e.nullable(input_schema)
258                            .map(|nullable| if nullable { Some(()) } else { None })
259                            .transpose()
260                    })
261                    .transpose()?;
262                Ok(match has_nullable {
263                    // If a nullable subexpression is found, the result may also be nullable.
264                    Some(_) => true,
265                    // If the list is too long, we assume it is nullable.
266                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
267                    // All the subexpressions are non-nullable, so the result must be non-nullable.
268                    _ => false,
269                })
270            }
271
272            Expr::Between(Between {
273                expr, low, high, ..
274            }) => Ok(expr.nullable(input_schema)?
275                || low.nullable(input_schema)?
276                || high.nullable(input_schema)?),
277
278            Expr::Column(c) => input_schema.nullable(c),
279            Expr::OuterReferenceColumn(_, _) => Ok(true),
280            Expr::Literal(value, _) => Ok(value.is_null()),
281            Expr::Case(case) => {
282                // This expression is nullable if any of the input expressions are nullable
283                let then_nullable = case
284                    .when_then_expr
285                    .iter()
286                    .map(|(_, t)| t.nullable(input_schema))
287                    .collect::<Result<Vec<_>>>()?;
288                if then_nullable.contains(&true) {
289                    Ok(true)
290                } else if let Some(e) = &case.else_expr {
291                    e.nullable(input_schema)
292                } else {
293                    // CASE produces NULL if there is no `else` expr
294                    // (aka when none of the `when_then_exprs` match)
295                    Ok(true)
296                }
297            }
298            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
299            Expr::ScalarFunction(_func) => {
300                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
301                Ok(nullable)
302            }
303            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
304                Ok(func.is_nullable())
305            }
306            Expr::WindowFunction(window_function) => self
307                .data_type_and_nullable_with_window_function(
308                    input_schema,
309                    window_function,
310                )
311                .map(|(_, nullable)| nullable),
312            Expr::ScalarVariable(_, _)
313            | Expr::TryCast { .. }
314            | Expr::Unnest(_)
315            | Expr::Placeholder(_) => Ok(true),
316            Expr::IsNull(_)
317            | Expr::IsNotNull(_)
318            | Expr::IsTrue(_)
319            | Expr::IsFalse(_)
320            | Expr::IsUnknown(_)
321            | Expr::IsNotTrue(_)
322            | Expr::IsNotFalse(_)
323            | Expr::IsNotUnknown(_)
324            | Expr::Exists { .. } => Ok(false),
325            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
326            Expr::ScalarSubquery(subquery) => {
327                Ok(subquery.subquery.schema().field(0).is_nullable())
328            }
329            Expr::BinaryExpr(BinaryExpr {
330                ref left,
331                ref right,
332                ..
333            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
334            Expr::Like(Like { expr, pattern, .. })
335            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
336                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
337            }
338            #[expect(deprecated)]
339            Expr::Wildcard { .. } => Ok(false),
340            Expr::GroupingSet(_) => {
341                // Grouping sets do not really have the concept of nullable and do not appear
342                // in projections
343                Ok(true)
344            }
345        }
346    }
347
348    fn metadata(&self, schema: &dyn ExprSchema) -> Result<FieldMetadata> {
349        self.to_field(schema)
350            .map(|(_, field)| FieldMetadata::from(field.metadata()))
351    }
352
353    /// Returns the datatype and nullability of the expression based on [ExprSchema].
354    ///
355    /// Note: [`DFSchema`] implements [ExprSchema].
356    ///
357    /// [`DFSchema`]: datafusion_common::DFSchema
358    ///
359    /// # Errors
360    ///
361    /// This function errors when it is not possible to compute its
362    /// datatype or nullability.
363    fn data_type_and_nullable(
364        &self,
365        schema: &dyn ExprSchema,
366    ) -> Result<(DataType, bool)> {
367        let field = self.to_field(schema)?.1;
368
369        Ok((field.data_type().clone(), field.is_nullable()))
370    }
371
372    /// Returns a [arrow::datatypes::Field] compatible with this expression.
373    ///
374    /// So for example, a projected expression `col(c1) + col(c2)` is
375    /// placed in an output field **named** col("c1 + c2")
376    fn to_field(
377        &self,
378        schema: &dyn ExprSchema,
379    ) -> Result<(Option<TableReference>, Arc<Field>)> {
380        let (relation, schema_name) = self.qualified_name();
381        #[allow(deprecated)]
382        let field = match self {
383            Expr::Alias(Alias {
384                expr,
385                name,
386                metadata,
387                ..
388            }) => {
389                let field = match &**expr {
390                    Expr::Placeholder(Placeholder { data_type, .. }) => {
391                        match &data_type {
392                            None => schema
393                                .data_type_and_nullable(&Column::from_name(name))
394                                .map(|(d, n)| Field::new(&schema_name, d.clone(), n)),
395                            Some(dt) => Ok(Field::new(
396                                &schema_name,
397                                dt.clone(),
398                                expr.nullable(schema)?,
399                            )),
400                        }
401                    }
402                    _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()),
403                }?;
404
405                let mut combined_metadata = expr.metadata(schema)?;
406                if let Some(metadata) = metadata {
407                    combined_metadata.extend(metadata.clone());
408                }
409
410                Ok(Arc::new(combined_metadata.add_to_field(field)))
411            }
412            Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f),
413            Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())),
414            Expr::OuterReferenceColumn(ty, _) => {
415                Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
416            }
417            Expr::ScalarVariable(ty, _) => {
418                Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
419            }
420            Expr::Literal(l, metadata) => {
421                let mut field = Field::new(&schema_name, l.data_type(), l.is_null());
422                if let Some(metadata) = metadata {
423                    field = metadata.add_to_field(field);
424                }
425                Ok(Arc::new(field))
426            }
427            Expr::IsNull(_)
428            | Expr::IsNotNull(_)
429            | Expr::IsTrue(_)
430            | Expr::IsFalse(_)
431            | Expr::IsUnknown(_)
432            | Expr::IsNotTrue(_)
433            | Expr::IsNotFalse(_)
434            | Expr::IsNotUnknown(_)
435            | Expr::Exists { .. } => {
436                Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false)))
437            }
438            Expr::ScalarSubquery(subquery) => {
439                Ok(Arc::new(subquery.subquery.schema().field(0).clone()))
440            }
441            Expr::BinaryExpr(BinaryExpr {
442                ref left,
443                ref right,
444                ref op,
445            }) => {
446                let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
447                let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
448                let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
449                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
450                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
451                Ok(Arc::new(Field::new(
452                    &schema_name,
453                    coercer.get_result_type()?,
454                    lhs_nullable || rhs_nullable,
455                )))
456            }
457            Expr::WindowFunction(window_function) => {
458                let (dt, nullable) = self.data_type_and_nullable_with_window_function(
459                    schema,
460                    window_function,
461                )?;
462                Ok(Arc::new(Field::new(&schema_name, dt, nullable)))
463            }
464            Expr::AggregateFunction(aggregate_function) => {
465                let AggregateFunction {
466                    func,
467                    params: AggregateFunctionParams { args, .. },
468                    ..
469                } = aggregate_function;
470
471                let fields = args
472                    .iter()
473                    .map(|e| e.to_field(schema).map(|(_, f)| f))
474                    .collect::<Result<Vec<_>>>()?;
475                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
476                let new_fields = fields_with_aggregate_udf(&fields, func)
477                    .map_err(|err| {
478                        let arg_types = fields
479                            .iter()
480                            .map(|f| f.data_type())
481                            .cloned()
482                            .collect::<Vec<_>>();
483                        plan_datafusion_err!(
484                            "{} {}",
485                            match err {
486                                DataFusionError::Plan(msg) => msg,
487                                err => err.to_string(),
488                            },
489                            utils::generate_signature_error_msg(
490                                func.name(),
491                                func.signature().clone(),
492                                &arg_types,
493                            )
494                        )
495                    })?
496                    .into_iter()
497                    .collect::<Vec<_>>();
498
499                func.return_field(&new_fields)
500            }
501            Expr::ScalarFunction(ScalarFunction { func, args }) => {
502                let (arg_types, fields): (Vec<DataType>, Vec<Arc<Field>>) = args
503                    .iter()
504                    .map(|e| e.to_field(schema).map(|(_, f)| f))
505                    .collect::<Result<Vec<_>>>()?
506                    .into_iter()
507                    .map(|f| (f.data_type().clone(), f))
508                    .unzip();
509                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
510                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
511                    .map_err(|err| {
512                        plan_datafusion_err!(
513                            "{} {}",
514                            match err {
515                                DataFusionError::Plan(msg) => msg,
516                                err => err.to_string(),
517                            },
518                            utils::generate_signature_error_msg(
519                                func.name(),
520                                func.signature().clone(),
521                                &arg_types,
522                            )
523                        )
524                    })?;
525                let new_fields = fields
526                    .into_iter()
527                    .zip(new_data_types)
528                    .map(|(f, d)| f.as_ref().clone().with_data_type(d))
529                    .map(Arc::new)
530                    .collect::<Vec<FieldRef>>();
531
532                let arguments = args
533                    .iter()
534                    .map(|e| match e {
535                        Expr::Literal(sv, _) => Some(sv),
536                        _ => None,
537                    })
538                    .collect::<Vec<_>>();
539                let args = ReturnFieldArgs {
540                    arg_fields: &new_fields,
541                    scalar_arguments: &arguments,
542                };
543
544                func.return_field_from_args(args)
545            }
546            // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
547            Expr::Cast(Cast { expr, data_type }) => expr
548                .to_field(schema)
549                .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone()))
550                .map(Arc::new),
551            Expr::Like(_)
552            | Expr::SimilarTo(_)
553            | Expr::Not(_)
554            | Expr::Between(_)
555            | Expr::Case(_)
556            | Expr::TryCast(_)
557            | Expr::InList(_)
558            | Expr::InSubquery(_)
559            | Expr::Wildcard { .. }
560            | Expr::GroupingSet(_)
561            | Expr::Placeholder(_)
562            | Expr::Unnest(_) => Ok(Arc::new(Field::new(
563                &schema_name,
564                self.get_type(schema)?,
565                self.nullable(schema)?,
566            ))),
567        }?;
568
569        Ok((
570            relation,
571            Arc::new(field.as_ref().clone().with_name(schema_name)),
572        ))
573    }
574
575    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
576    ///
577    /// # Errors
578    ///
579    /// This function errors when it is impossible to cast the
580    /// expression to the target [arrow::datatypes::DataType].
581    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
582        let this_type = self.get_type(schema)?;
583        if this_type == *cast_to_type {
584            return Ok(self);
585        }
586
587        // TODO(kszucs): Most of the operations do not validate the type correctness
588        // like all of the binary expressions below. Perhaps Expr should track the
589        // type of the expression?
590
591        if can_cast_types(&this_type, cast_to_type) {
592            match self {
593                Expr::ScalarSubquery(subquery) => {
594                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
595                }
596                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
597            }
598        } else {
599            plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
600        }
601    }
602}
603
604impl Expr {
605    /// Common method for window functions that applies type coercion
606    /// to all arguments of the window function to check if it matches
607    /// its signature.
608    ///
609    /// If successful, this method returns the data type and
610    /// nullability of the window function's result.
611    ///
612    /// Otherwise, returns an error if there's a type mismatch between
613    /// the window function's signature and the provided arguments.
614    fn data_type_and_nullable_with_window_function(
615        &self,
616        schema: &dyn ExprSchema,
617        window_function: &WindowFunction,
618    ) -> Result<(DataType, bool)> {
619        let WindowFunction {
620            fun,
621            params: WindowFunctionParams { args, .. },
622            ..
623        } = window_function;
624
625        let fields = args
626            .iter()
627            .map(|e| e.to_field(schema).map(|(_, f)| f))
628            .collect::<Result<Vec<_>>>()?;
629        match fun {
630            WindowFunctionDefinition::AggregateUDF(udaf) => {
631                let data_types = fields
632                    .iter()
633                    .map(|f| f.data_type())
634                    .cloned()
635                    .collect::<Vec<_>>();
636                let new_fields = fields_with_aggregate_udf(&fields, udaf)
637                    .map_err(|err| {
638                        plan_datafusion_err!(
639                            "{} {}",
640                            match err {
641                                DataFusionError::Plan(msg) => msg,
642                                err => err.to_string(),
643                            },
644                            utils::generate_signature_error_msg(
645                                fun.name(),
646                                fun.signature(),
647                                &data_types
648                            )
649                        )
650                    })?
651                    .into_iter()
652                    .collect::<Vec<_>>();
653
654                let return_field = udaf.return_field(&new_fields)?;
655
656                Ok((return_field.data_type().clone(), return_field.is_nullable()))
657            }
658            WindowFunctionDefinition::WindowUDF(udwf) => {
659                let data_types = fields
660                    .iter()
661                    .map(|f| f.data_type())
662                    .cloned()
663                    .collect::<Vec<_>>();
664                let new_fields = fields_with_window_udf(&fields, udwf)
665                    .map_err(|err| {
666                        plan_datafusion_err!(
667                            "{} {}",
668                            match err {
669                                DataFusionError::Plan(msg) => msg,
670                                err => err.to_string(),
671                            },
672                            utils::generate_signature_error_msg(
673                                fun.name(),
674                                fun.signature(),
675                                &data_types
676                            )
677                        )
678                    })?
679                    .into_iter()
680                    .collect::<Vec<_>>();
681                let (_, function_name) = self.qualified_name();
682                let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name);
683
684                udwf.field(field_args)
685                    .map(|field| (field.data_type().clone(), field.is_nullable()))
686            }
687        }
688    }
689}
690
691/// Cast subquery in InSubquery/ScalarSubquery to a given type.
692///
693/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
694///    columns), it casts the first expression in the projection to the target type and creates a
695///    new projection with the casted expression.
696/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
697///    with the casted first column.
698///
699pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
700    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
701        return Ok(subquery);
702    }
703
704    let plan = subquery.subquery.as_ref();
705    let new_plan = match plan {
706        LogicalPlan::Projection(projection) => {
707            let cast_expr = projection.expr[0]
708                .clone()
709                .cast_to(cast_to_type, projection.input.schema())?;
710            LogicalPlan::Projection(Projection::try_new(
711                vec![cast_expr],
712                Arc::clone(&projection.input),
713            )?)
714        }
715        _ => {
716            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
717                .cast_to(cast_to_type, subquery.subquery.schema())?;
718            LogicalPlan::Projection(Projection::try_new(
719                vec![cast_expr],
720                subquery.subquery,
721            )?)
722        }
723    };
724    Ok(Subquery {
725        subquery: Arc::new(new_plan),
726        outer_ref_columns: subquery.outer_ref_columns,
727        spans: Spans::new(),
728    })
729}
730
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue};

    /// Applies the named `Expr` predicate builder (e.g. `is_null`) to a NULL
    /// literal and asserts that the resulting expression is not nullable.
    macro_rules! assert_predicate_not_nullable {
        ($method:ident) => {{
            let predicate = lit(ScalarValue::Null).$method();
            let schema = MockExprSchema::new();
            assert!(!predicate.nullable(&schema).unwrap());
        }};
    }

    /// Builds a mock schema whose (single) column has the given type and nullability.
    fn schema_of(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let comparison = col("foo").eq(lit(1));
        let non_nullable = MockExprSchema::new();
        let nullable = MockExprSchema::new().with_nullable(true);
        assert!(!comparison.nullable(&non_nullable).unwrap());
        assert!(comparison.nullable(&nullable).unwrap());

        assert_predicate_not_nullable!(is_null);
        assert_predicate_not_nullable!(is_not_null);
        assert_predicate_not_nullable!(is_true);
        assert_predicate_not_nullable!(is_not_true);
        assert_predicate_not_nullable!(is_false);
        assert_predicate_not_nullable!(is_not_false);
        assert_predicate_not_nullable!(is_unknown);
        assert_predicate_not_nullable!(is_not_unknown);
    }

    #[test]
    fn test_between_nullability() {
        let int_schema = |nullable| schema_of(DataType::Int32, nullable);
        let null_int = lit(ScalarValue::Int32(None));

        let bounded = col("foo").between(lit(1), lit(2));
        assert!(!bounded.nullable(&int_schema(false)).unwrap());
        assert!(bounded.nullable(&int_schema(true)).unwrap());

        // Any NULL bound is expected to make BETWEEN nullable, even when the
        // tested column itself is not nullable.
        for (low, high) in [
            (null_int.clone(), lit(2)),
            (lit(1), null_int.clone()),
            (null_int.clone(), null_int),
        ] {
            let expr = col("foo").between(low, high);
            assert!(expr.nullable(&int_schema(false)).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let int_schema = |nullable| schema_of(DataType::Int32, nullable);

        let short_list = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!short_list.nullable(&int_schema(false)).unwrap());
        assert!(short_list.nullable(&int_schema(true)).unwrap());
        // An error from the schema's nullable() lookup must be propagated.
        let failing_schema = int_schema(false).with_error_on_nullable(true);
        assert!(short_list.nullable(&failing_schema).is_err());

        let with_null =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null.nullable(&int_schema(false)).unwrap());

        // Unlike the five-item list above, a six-item list of non-null
        // literals is expected to be reported as nullable.
        let long_list = col("foo").in_list(vec![lit(1); 6], false);
        assert!(long_list.nullable(&int_schema(false)).unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let utf8_schema = |nullable| schema_of(DataType::Utf8, nullable);

        let pattern_match = col("foo").like(lit("bar"));
        assert!(!pattern_match.nullable(&utf8_schema(false)).unwrap());
        assert!(pattern_match.nullable(&utf8_schema(true)).unwrap());

        let null_pattern = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(null_pattern.nullable(&utf8_schema(false)).unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let actual = col("foo").get_type(&schema).unwrap();
        assert_eq!(DataType::Utf8, actual);
    }

    #[test]
    fn test_expr_metadata() {
        let mut entries = HashMap::new();
        entries.insert("bar".to_string(), "buzz".to_string());
        let metadata = FieldMetadata::from(entries);

        let foo = col("foo");
        let mock = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(metadata.clone());

        // Column references, aliases and casts should all carry the column metadata.
        assert_eq!(metadata, foo.metadata(&mock).unwrap());
        let aliased = foo.clone().alias("bar");
        assert_eq!(metadata, aliased.metadata(&mock).unwrap());
        let casted = foo.clone().cast_to(&DataType::Int64, &mock).unwrap();
        assert_eq!(metadata, casted.metadata(&mock).unwrap());

        // A real DFSchema built from a metadata-bearing field must expose the
        // same metadata through the to_field path.
        let df_schema = DFSchema::from_unqualified_fields(
            vec![metadata.add_to_field(Field::new("foo", DataType::Int32, true))].into(),
            std::collections::HashMap::new(),
        )
        .unwrap();
        assert_eq!(metadata, foo.metadata(&df_schema).unwrap());
    }

    /// Minimal [`ExprSchema`] for these tests. Every column resolves to the same
    /// `field`, and `nullable` can be configured to fail on demand.
    #[derive(Debug)]
    struct MockExprSchema {
        field: Field,
        error_on_nullable: bool,
    }

    impl MockExprSchema {
        fn new() -> Self {
            let field = Field::new("mock_field", DataType::Null, false);
            Self {
                field,
                error_on_nullable: false,
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self {
                field: self.field.with_nullable(nullable),
                ..self
            }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self {
                field: self.field.with_data_type(data_type),
                ..self
            }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: FieldMetadata) -> Self {
            Self {
                field: metadata.add_to_field(self.field),
                ..self
            }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if self.error_on_nullable {
                return internal_err!("nullable error");
            }
            Ok(self.field.is_nullable())
        }

        fn field_from_column(&self, _col: &Column) -> Result<&Field> {
            let field = &self.field;
            Ok(field)
        }
    }
}