polars_plan/dsl/function_expr/schema.rs

1use polars_core::utils::materialize_dyn_int;
2
3use super::*;
4
5impl FunctionExpr {
6    pub(crate) fn get_field(
7        &self,
8        _input_schema: &Schema,
9        _cntxt: Context,
10        fields: &[Field],
11    ) -> PolarsResult<Field> {
12        use FunctionExpr::*;
13
14        let mapper = FieldsMapper { fields };
15        match self {
16            // Namespaces
17            #[cfg(feature = "dtype-array")]
18            ArrayExpr(func) => func.get_field(mapper),
19            BinaryExpr(s) => s.get_field(mapper),
20            #[cfg(feature = "dtype-categorical")]
21            Categorical(func) => func.get_field(mapper),
22            ListExpr(func) => func.get_field(mapper),
23            #[cfg(feature = "strings")]
24            StringExpr(s) => s.get_field(mapper),
25            #[cfg(feature = "dtype-struct")]
26            StructExpr(s) => s.get_field(mapper),
27            #[cfg(feature = "temporal")]
28            TemporalExpr(fun) => fun.get_field(mapper),
29            #[cfg(feature = "bitwise")]
30            Bitwise(fun) => fun.get_field(mapper),
31
32            // Other expressions
33            Boolean(func) => func.get_field(mapper),
34            #[cfg(feature = "business")]
35            Business(func) => func.get_field(mapper),
36            #[cfg(feature = "abs")]
37            Abs => mapper.with_same_dtype(),
38            Negate => mapper.with_same_dtype(),
39            NullCount => mapper.with_dtype(IDX_DTYPE),
40            Pow(pow_function) => match pow_function {
41                PowFunction::Generic => mapper.pow_dtype(),
42                _ => mapper.map_to_float_dtype(),
43            },
44            Coalesce => mapper.map_to_supertype(),
45            #[cfg(feature = "row_hash")]
46            Hash(..) => mapper.with_dtype(DataType::UInt64),
47            #[cfg(feature = "arg_where")]
48            ArgWhere => mapper.with_dtype(IDX_DTYPE),
49            #[cfg(feature = "index_of")]
50            IndexOf => mapper.with_dtype(IDX_DTYPE),
51            #[cfg(feature = "search_sorted")]
52            SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
53            #[cfg(feature = "range")]
54            Range(func) => func.get_field(mapper),
55            #[cfg(feature = "trigonometry")]
56            Trigonometry(_) => mapper.map_to_float_dtype(),
57            #[cfg(feature = "trigonometry")]
58            Atan2 => mapper.map_to_float_dtype(),
59            #[cfg(feature = "sign")]
60            Sign => mapper.with_dtype(DataType::Int64),
61            FillNull  => mapper.map_to_supertype(),
62            #[cfg(feature = "rolling_window")]
63            RollingExpr(rolling_func, ..) => {
64                use RollingFunction::*;
65                match rolling_func {
66                    Min(_) | Max(_) => mapper.with_same_dtype(),
67                    Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(),
68                    Sum(_) => mapper.sum_dtype(),
69                    #[cfg(feature = "cov")]
70                    CorrCov {..} => mapper.map_to_float_dtype(),
71                    #[cfg(feature = "moment")]
72                    Skew(..) | Kurtosis(..) => mapper.map_to_float_dtype(),
73                }
74            },
75            #[cfg(feature = "rolling_window_by")]
76            RollingExprBy(rolling_func, ..) => {
77                use RollingFunctionBy::*;
78                match rolling_func {
79                    MinBy(_) | MaxBy(_) => mapper.with_same_dtype(),
80                    MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(),
81                    SumBy(_) => mapper.sum_dtype(),
82                }
83            },
84            ShiftAndFill => mapper.with_same_dtype(),
85            DropNans => mapper.with_same_dtype(),
86            DropNulls => mapper.with_same_dtype(),
87            #[cfg(feature = "round_series")]
88            Clip { .. } => mapper.with_same_dtype(),
89            #[cfg(feature = "mode")]
90            Mode => mapper.with_same_dtype(),
91            #[cfg(feature = "moment")]
92            Skew(_) => mapper.with_dtype(DataType::Float64),
93            #[cfg(feature = "moment")]
94            Kurtosis(..) => mapper.with_dtype(DataType::Float64),
95            ArgUnique => mapper.with_dtype(IDX_DTYPE),
96            Repeat => mapper.with_same_dtype(),
97            #[cfg(feature = "rank")]
98            Rank { options, .. } => mapper.with_dtype(match options.method {
99                RankMethod::Average => DataType::Float64,
100                _ => IDX_DTYPE,
101            }),
102            #[cfg(feature = "dtype-struct")]
103            AsStruct => Ok(Field::new(
104                fields[0].name().clone(),
105                DataType::Struct(fields.to_vec()),
106            )),
107            #[cfg(feature = "top_k")]
108            TopK { .. } => mapper.with_same_dtype(),
109            #[cfg(feature = "top_k")]
110            TopKBy { .. } => mapper.with_same_dtype(),
111            #[cfg(feature = "dtype-struct")]
112            ValueCounts {
113                sort: _,
114                parallel: _,
115                name,
116                normalize,
117            } => mapper.map_dtype(|dt| {
118                let count_dt = if *normalize {
119                    DataType::Float64
120                } else {
121                    IDX_DTYPE
122                };
123                DataType::Struct(vec![
124                    Field::new(fields[0].name().clone(), dt.clone()),
125                    Field::new(name.clone(), count_dt),
126                ])
127            }),
128            #[cfg(feature = "unique_counts")]
129            UniqueCounts => mapper.with_dtype(IDX_DTYPE),
130            Shift | Reverse => mapper.with_same_dtype(),
131            #[cfg(feature = "cum_agg")]
132            CumCount { .. } => mapper.with_dtype(IDX_DTYPE),
133            #[cfg(feature = "cum_agg")]
134            CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum),
135            #[cfg(feature = "cum_agg")]
136            CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod),
137            #[cfg(feature = "cum_agg")]
138            CumMin { .. } => mapper.with_same_dtype(),
139            #[cfg(feature = "cum_agg")]
140            CumMax { .. } => mapper.with_same_dtype(),
141            #[cfg(feature = "approx_unique")]
142            ApproxNUnique => mapper.with_dtype(IDX_DTYPE),
143            #[cfg(feature = "hist")]
144            Hist {
145                include_category,
146                include_breakpoint,
147                ..
148            } => {
149                if *include_breakpoint || *include_category {
150                    let mut fields = Vec::with_capacity(3);
151                    if *include_breakpoint {
152                        fields.push(Field::new(
153                            PlSmallStr::from_static("breakpoint"),
154                            DataType::Float64,
155                        ));
156                    }
157                    if *include_category {
158                        fields.push(Field::new(
159                            PlSmallStr::from_static("category"),
160                            DataType::Categorical(None, Default::default()),
161                        ));
162                    }
163                    fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE));
164                    mapper.with_dtype(DataType::Struct(fields))
165                } else {
166                    mapper.with_dtype(IDX_DTYPE)
167                }
168            },
169            #[cfg(feature = "diff")]
170            Diff(_) => mapper.map_dtype(|dt| match dt {
171                #[cfg(feature = "dtype-datetime")]
172                DataType::Datetime(tu, _) => DataType::Duration(*tu),
173                #[cfg(feature = "dtype-date")]
174                DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
175                #[cfg(feature = "dtype-time")]
176                DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
177                DataType::UInt64 | DataType::UInt32 => DataType::Int64,
178                DataType::UInt16 => DataType::Int32,
179                DataType::UInt8 => DataType::Int16,
180                dt => dt.clone(),
181            }),
182            #[cfg(feature = "pct_change")]
183            PctChange => mapper.map_dtype(|dt| match dt {
184                DataType::Float64 | DataType::Float32 => dt.clone(),
185                _ => DataType::Float64,
186            }),
187            #[cfg(feature = "interpolate")]
188            Interpolate(method) => match method {
189                InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(),
190                InterpolationMethod::Nearest => mapper.with_same_dtype(),
191            },
192            #[cfg(feature = "interpolate_by")]
193            InterpolateBy => mapper.map_numeric_to_float_dtype(),
194            ShrinkType => {
195                // we return the smallest type this can return
196                // this might not be correct once the actual data
197                // comes in, but if we set the smallest datatype
198                // we have the least chance that the smaller dtypes
199                // get cast to larger types in type-coercion
200                // this will lead to an incorrect schema in polars
201                // but we because only the numeric types deviate in
202                // bit size this will likely not lead to issues
203                mapper.map_dtype(|dt| {
204                    if dt.is_primitive_numeric() {
205                        if dt.is_float() {
206                            DataType::Float32
207                        } else if dt.is_unsigned_integer() {
208                            DataType::Int8
209                        } else {
210                            DataType::UInt8
211                        }
212                    } else {
213                        dt.clone()
214                    }
215                })
216            },
217            #[cfg(feature = "log")]
218            Entropy { .. } | Log { .. } | Log1p | Exp => mapper.map_to_float_dtype(),
219            Unique(_) => mapper.with_same_dtype(),
220            #[cfg(feature = "round_series")]
221            Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(),
222            UpperBound | LowerBound => mapper.with_same_dtype(),
223            #[cfg(feature = "fused")]
224            Fused(_) => mapper.map_to_supertype(),
225            ConcatExpr(_) => mapper.map_to_supertype(),
226            #[cfg(feature = "cov")]
227            Correlation { .. } => mapper.map_to_float_dtype(),
228            #[cfg(feature = "peaks")]
229            PeakMin => mapper.with_same_dtype(),
230            #[cfg(feature = "peaks")]
231            PeakMax => mapper.with_same_dtype(),
232            #[cfg(feature = "cutqcut")]
233            Cut {
234                include_breaks: false,
235                ..
236            } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
237            #[cfg(feature = "cutqcut")]
238            Cut {
239                include_breaks: true,
240                ..
241            } => {
242                let struct_dt = DataType::Struct(vec![
243                    Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
244                    Field::new(
245                        PlSmallStr::from_static("category"),
246                        DataType::Categorical(None, Default::default()),
247                    ),
248                ]);
249                mapper.with_dtype(struct_dt)
250            },
251            #[cfg(feature = "repeat_by")]
252            RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
253            #[cfg(feature = "dtype-array")]
254            Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
255                let dtype = dt.inner_dtype().unwrap_or(dt).clone();
256
257                if dims.len() == 1 {
258                    return Ok(dtype);
259                }
260
261                let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count();
262
263                polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
264
265                let mut inferred_size = 0;
266                if num_infers == 1 {
267                    let mut total_size = 1u64;
268                    let mut current = dt;
269                    while let DataType::Array(dt, width) = current {
270                        if *width == 0 {
271                            total_size = 0;
272                            break;
273                        }
274
275                        current = dt.as_ref();
276                        total_size *= *width as u64;
277                    }
278
279                    let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::<u64>();
280                    inferred_size = total_size / current_size;
281                }
282
283                let mut prev_dtype = dtype.leaf_dtype().clone();
284
285                // We pop the outer dimension as that is the height of the series.
286                for dim in &dims[1..] {
287                    prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize);
288                }
289                Ok(prev_dtype)
290            }),
291            #[cfg(feature = "cutqcut")]
292            QCut {
293                include_breaks: false,
294                ..
295            } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
296            #[cfg(feature = "cutqcut")]
297            QCut {
298                include_breaks: true,
299                ..
300            } => {
301                let struct_dt = DataType::Struct(vec![
302                    Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
303                    Field::new(
304                        PlSmallStr::from_static("category"),
305                        DataType::Categorical(None, Default::default()),
306                    ),
307                ]);
308                mapper.with_dtype(struct_dt)
309            },
310            #[cfg(feature = "rle")]
311            RLE => mapper.map_dtype(|dt| {
312                DataType::Struct(vec![
313                    Field::new(PlSmallStr::from_static("len"), IDX_DTYPE),
314                    Field::new(PlSmallStr::from_static("value"), dt.clone()),
315                ])
316            }),
317            #[cfg(feature = "rle")]
318            RLEID => mapper.with_dtype(IDX_DTYPE),
319            ToPhysical => mapper.to_physical_type(),
320            #[cfg(feature = "random")]
321            Random { .. } => mapper.with_same_dtype(),
322            SetSortedFlag(_) => mapper.with_same_dtype(),
323            #[cfg(feature = "ffi_plugin")]
324            FfiPlugin {
325                flags: _,
326                lib,
327                symbol,
328                kwargs,
329            } => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
330            MaxHorizontal => mapper.map_to_supertype(),
331            MinHorizontal => mapper.map_to_supertype(),
332            SumHorizontal { .. } => {
333                mapper.map_to_supertype().map(|mut f| {
334                    if f.dtype == DataType::Boolean {
335                        f.dtype = IDX_DTYPE;
336                    }
337                    f
338                })
339            },
340            MeanHorizontal { .. } => {
341                mapper.map_to_supertype().map(|mut f| {
342                    match f.dtype {
343                        dt @ DataType::Float32 => { f.dtype = dt; },
344                        _ => { f.dtype = DataType::Float64; },
345                    };
346                    f
347                })
348            }
349            #[cfg(feature = "ewma")]
350            EwmMean { .. } => mapper.map_to_float_dtype(),
351            #[cfg(feature = "ewma_by")]
352            EwmMeanBy { .. } => mapper.map_to_float_dtype(),
353            #[cfg(feature = "ewma")]
354            EwmStd { .. } => mapper.map_to_float_dtype(),
355            #[cfg(feature = "ewma")]
356            EwmVar { .. } => mapper.map_to_float_dtype(),
357            #[cfg(feature = "replace")]
358            Replace => mapper.with_same_dtype(),
359            #[cfg(feature = "replace")]
360            ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
361            FillNullWithStrategy(_) => mapper.with_same_dtype(),
362            GatherEvery { .. } => mapper.with_same_dtype(),
363            #[cfg(feature = "reinterpret")]
364            Reinterpret(signed) => {
365                let dt = if *signed {
366                    DataType::Int64
367                } else {
368                    DataType::UInt64
369                };
370                mapper.with_dtype(dt)
371            },
372            ExtendConstant => mapper.with_same_dtype(),
373        }
374    }
375
376    pub(crate) fn output_name(&self) -> Option<OutputName> {
377        match self {
378            #[cfg(feature = "dtype-struct")]
379            FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => {
380                Some(OutputName::Field(name.clone()))
381            },
382            _ => None,
383        }
384    }
385}
386
387pub struct FieldsMapper<'a> {
388    fields: &'a [Field],
389}
390
391impl<'a> FieldsMapper<'a> {
392    pub fn new(fields: &'a [Field]) -> Self {
393        Self { fields }
394    }
395
396    pub fn args(&self) -> &[Field] {
397        self.fields
398    }
399
400    /// Field with the same dtype.
401    pub fn with_same_dtype(&self) -> PolarsResult<Field> {
402        self.map_dtype(|dtype| dtype.clone())
403    }
404
405    /// Set a dtype.
406    pub fn with_dtype(&self, dtype: DataType) -> PolarsResult<Field> {
407        Ok(Field::new(self.fields[0].name().clone(), dtype))
408    }
409
410    /// Map a single dtype.
411    pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
412        let dtype = func(self.fields[0].dtype());
413        Ok(Field::new(self.fields[0].name().clone(), dtype))
414    }
415
416    pub fn get_fields_lens(&self) -> usize {
417        self.fields.len()
418    }
419
420    /// Map a single field with a potentially failing mapper function.
421    pub fn try_map_field(
422        &self,
423        func: impl FnOnce(&Field) -> PolarsResult<Field>,
424    ) -> PolarsResult<Field> {
425        func(&self.fields[0])
426    }
427
428    /// Map to a float supertype.
429    pub fn map_to_float_dtype(&self) -> PolarsResult<Field> {
430        self.map_dtype(|dtype| match dtype {
431            DataType::Float32 => DataType::Float32,
432            _ => DataType::Float64,
433        })
434    }
435
436    /// Map to a float supertype if numeric, else preserve
437    pub fn map_numeric_to_float_dtype(&self) -> PolarsResult<Field> {
438        self.map_dtype(|dtype| {
439            if dtype.is_primitive_numeric() {
440                match dtype {
441                    DataType::Float32 => DataType::Float32,
442                    _ => DataType::Float64,
443                }
444            } else {
445                dtype.clone()
446            }
447        })
448    }
449
450    /// Map to a physical type.
451    pub fn to_physical_type(&self) -> PolarsResult<Field> {
452        self.map_dtype(|dtype| dtype.to_physical())
453    }
454
455    /// Map a single dtype with a potentially failing mapper function.
456    pub fn try_map_dtype(
457        &self,
458        func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
459    ) -> PolarsResult<Field> {
460        let dtype = func(self.fields[0].dtype())?;
461        Ok(Field::new(self.fields[0].name().clone(), dtype))
462    }
463
464    /// Map all dtypes with a potentially failing mapper function.
465    pub fn try_map_dtypes(
466        &self,
467        func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
468    ) -> PolarsResult<Field> {
469        let mut fld = self.fields[0].clone();
470        let dtypes = self
471            .fields
472            .iter()
473            .map(|fld| fld.dtype())
474            .collect::<Vec<_>>();
475        let new_type = func(&dtypes)?;
476        fld.coerce(new_type);
477        Ok(fld)
478    }
479
480    /// Map the dtype to the "supertype" of all fields.
481    pub fn map_to_supertype(&self) -> PolarsResult<Field> {
482        let st = args_to_supertype(self.fields)?;
483        let mut first = self.fields[0].clone();
484        first.coerce(st);
485        Ok(first)
486    }
487
488    /// Map the dtype to the dtype of the list/array elements.
489    pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult<Field> {
490        let mut first = self.fields[0].clone();
491        let dt = first
492            .dtype()
493            .inner_dtype()
494            .cloned()
495            .unwrap_or_else(|| DataType::Unknown(Default::default()));
496        first.coerce(dt);
497        Ok(first)
498    }
499
500    #[cfg(feature = "dtype-array")]
501    /// Map the dtype to the dtype of the array elements, with typo validation.
502    pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
503        let dt = self.fields[0].dtype();
504        match dt {
505            DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
506            _ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
507        }
508    }
509
510    /// Map the dtypes to the "supertype" of a list of lists.
511    pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
512        self.try_map_dtypes(|dts| {
513            let mut super_type_inner = None;
514
515            for dt in dts {
516                match dt {
517                    DataType::List(inner) => match super_type_inner {
518                        None => super_type_inner = Some(*inner.clone()),
519                        Some(st_inner) => {
520                            super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
521                        },
522                    },
523                    dt => match super_type_inner {
524                        None => super_type_inner = Some((*dt).clone()),
525                        Some(st_inner) => {
526                            super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
527                        },
528                    },
529                }
530            }
531            Ok(DataType::List(Box::new(super_type_inner.unwrap())))
532        })
533    }
534
535    /// Set the timezone of a datetime dtype.
536    #[cfg(feature = "timezones")]
537    pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult<Field> {
538        self.try_map_dtype(|dt| {
539            if let DataType::Datetime(tu, _) = dt {
540                Ok(DataType::Datetime(*tu, tz.cloned()))
541            } else {
542                polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime");
543            }
544        })
545    }
546
547    pub fn sum_dtype(&self) -> PolarsResult<Field> {
548        use DataType::*;
549        self.map_dtype(|dtype| match dtype {
550            Int8 | UInt8 | Int16 | UInt16 => Int64,
551            dt => dt.clone(),
552        })
553    }
554
555    pub fn nested_sum_type(&self) -> PolarsResult<Field> {
556        let mut first = self.fields[0].clone();
557        use DataType::*;
558        let dt = first
559            .dtype()
560            .inner_dtype()
561            .cloned()
562            .unwrap_or_else(|| Unknown(Default::default()));
563
564        match dt {
565            Boolean => first.coerce(IDX_DTYPE),
566            UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64),
567            _ => first.coerce(dt),
568        }
569        Ok(first)
570    }
571
572    pub fn nested_mean_median_type(&self) -> PolarsResult<Field> {
573        let mut first = self.fields[0].clone();
574        use DataType::*;
575        let dt = first
576            .dtype()
577            .inner_dtype()
578            .cloned()
579            .unwrap_or_else(|| Unknown(Default::default()));
580
581        let new_dt = match dt {
582            #[cfg(feature = "dtype-datetime")]
583            Date => Datetime(TimeUnit::Milliseconds, None),
584            dt if dt.is_temporal() => dt,
585            Float32 => Float32,
586            _ => Float64,
587        };
588        first.coerce(new_dt);
589        Ok(first)
590    }
591
592    pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
593        let base_dtype = self.fields[0].dtype();
594        let exponent_dtype = self.fields[1].dtype();
595        if base_dtype.is_integer() {
596            if exponent_dtype.is_float() {
597                Ok(Field::new(
598                    self.fields[0].name().clone(),
599                    exponent_dtype.clone(),
600                ))
601            } else {
602                Ok(Field::new(
603                    self.fields[0].name().clone(),
604                    base_dtype.clone(),
605                ))
606            }
607        } else {
608            Ok(Field::new(
609                self.fields[0].name().clone(),
610                base_dtype.clone(),
611            ))
612        }
613    }
614
615    #[cfg(feature = "extract_jsonpath")]
616    pub fn with_opt_dtype(&self, dtype: Option<DataType>) -> PolarsResult<Field> {
617        let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default()));
618        self.with_dtype(dtype)
619    }
620
621    #[cfg(feature = "replace")]
622    pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
623        let dtype = match return_dtype {
624            Some(dtype) => dtype,
625            None => {
626                let new = &self.fields[2];
627                let default = self.fields.get(3);
628
629                // @HACK: Related to implicit implode see #22149.
630                let inner_dtype = new.dtype().inner_dtype().unwrap_or(new.dtype());
631
632                match default {
633                    Some(default) => try_get_supertype(default.dtype(), inner_dtype)?,
634                    None => inner_dtype.clone(),
635                }
636            },
637        };
638        self.with_dtype(dtype)
639    }
640}
641
642pub(crate) fn args_to_supertype<D: AsRef<DataType>>(dtypes: &[D]) -> PolarsResult<DataType> {
643    let mut st = dtypes[0].as_ref().clone();
644    for dt in &dtypes[1..] {
645        st = try_get_supertype(&st, dt.as_ref())?
646    }
647
648    match (dtypes[0].as_ref(), &st) {
649        #[cfg(feature = "dtype-categorical")]
650        (DataType::Categorical(_, ord), DataType::String) => st = DataType::Categorical(None, *ord),
651        _ => {
652            if let DataType::Unknown(kind) = st {
653                match kind {
654                    UnknownKind::Float => st = DataType::Float64,
655                    UnknownKind::Int(v) => {
656                        st = materialize_dyn_int(v).dtype();
657                    },
658                    UnknownKind::Str => st = DataType::String,
659                    _ => {},
660                }
661            }
662        },
663    }
664
665    Ok(st)
666}