1use polars_core::utils::materialize_dyn_int;
2
3use super::*;
4
5impl FunctionExpr {
6 pub(crate) fn get_field(
7 &self,
8 _input_schema: &Schema,
9 _cntxt: Context,
10 fields: &[Field],
11 ) -> PolarsResult<Field> {
12 use FunctionExpr::*;
13
14 let mapper = FieldsMapper { fields };
15 match self {
16 #[cfg(feature = "dtype-array")]
18 ArrayExpr(func) => func.get_field(mapper),
19 BinaryExpr(s) => s.get_field(mapper),
20 #[cfg(feature = "dtype-categorical")]
21 Categorical(func) => func.get_field(mapper),
22 ListExpr(func) => func.get_field(mapper),
23 #[cfg(feature = "strings")]
24 StringExpr(s) => s.get_field(mapper),
25 #[cfg(feature = "dtype-struct")]
26 StructExpr(s) => s.get_field(mapper),
27 #[cfg(feature = "temporal")]
28 TemporalExpr(fun) => fun.get_field(mapper),
29 #[cfg(feature = "bitwise")]
30 Bitwise(fun) => fun.get_field(mapper),
31
32 Boolean(func) => func.get_field(mapper),
34 #[cfg(feature = "business")]
35 Business(func) => func.get_field(mapper),
36 #[cfg(feature = "abs")]
37 Abs => mapper.with_same_dtype(),
38 Negate => mapper.with_same_dtype(),
39 NullCount => mapper.with_dtype(IDX_DTYPE),
40 Pow(pow_function) => match pow_function {
41 PowFunction::Generic => mapper.pow_dtype(),
42 _ => mapper.map_to_float_dtype(),
43 },
44 Coalesce => mapper.map_to_supertype(),
45 #[cfg(feature = "row_hash")]
46 Hash(..) => mapper.with_dtype(DataType::UInt64),
47 #[cfg(feature = "arg_where")]
48 ArgWhere => mapper.with_dtype(IDX_DTYPE),
49 #[cfg(feature = "index_of")]
50 IndexOf => mapper.with_dtype(IDX_DTYPE),
51 #[cfg(feature = "search_sorted")]
52 SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
53 #[cfg(feature = "range")]
54 Range(func) => func.get_field(mapper),
55 #[cfg(feature = "trigonometry")]
56 Trigonometry(_) => mapper.map_to_float_dtype(),
57 #[cfg(feature = "trigonometry")]
58 Atan2 => mapper.map_to_float_dtype(),
59 #[cfg(feature = "sign")]
60 Sign => mapper.with_dtype(DataType::Int64),
61 FillNull => mapper.map_to_supertype(),
62 #[cfg(feature = "rolling_window")]
63 RollingExpr(rolling_func, ..) => {
64 use RollingFunction::*;
65 match rolling_func {
66 Min(_) | Max(_) => mapper.with_same_dtype(),
67 Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(),
68 Sum(_) => mapper.sum_dtype(),
69 #[cfg(feature = "cov")]
70 CorrCov {..} => mapper.map_to_float_dtype(),
71 #[cfg(feature = "moment")]
72 Skew(..) | Kurtosis(..) => mapper.map_to_float_dtype(),
73 }
74 },
75 #[cfg(feature = "rolling_window_by")]
76 RollingExprBy(rolling_func, ..) => {
77 use RollingFunctionBy::*;
78 match rolling_func {
79 MinBy(_) | MaxBy(_) => mapper.with_same_dtype(),
80 MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(),
81 SumBy(_) => mapper.sum_dtype(),
82 }
83 },
84 ShiftAndFill => mapper.with_same_dtype(),
85 DropNans => mapper.with_same_dtype(),
86 DropNulls => mapper.with_same_dtype(),
87 #[cfg(feature = "round_series")]
88 Clip { .. } => mapper.with_same_dtype(),
89 #[cfg(feature = "mode")]
90 Mode => mapper.with_same_dtype(),
91 #[cfg(feature = "moment")]
92 Skew(_) => mapper.with_dtype(DataType::Float64),
93 #[cfg(feature = "moment")]
94 Kurtosis(..) => mapper.with_dtype(DataType::Float64),
95 ArgUnique => mapper.with_dtype(IDX_DTYPE),
96 Repeat => mapper.with_same_dtype(),
97 #[cfg(feature = "rank")]
98 Rank { options, .. } => mapper.with_dtype(match options.method {
99 RankMethod::Average => DataType::Float64,
100 _ => IDX_DTYPE,
101 }),
102 #[cfg(feature = "dtype-struct")]
103 AsStruct => Ok(Field::new(
104 fields[0].name().clone(),
105 DataType::Struct(fields.to_vec()),
106 )),
107 #[cfg(feature = "top_k")]
108 TopK { .. } => mapper.with_same_dtype(),
109 #[cfg(feature = "top_k")]
110 TopKBy { .. } => mapper.with_same_dtype(),
111 #[cfg(feature = "dtype-struct")]
112 ValueCounts {
113 sort: _,
114 parallel: _,
115 name,
116 normalize,
117 } => mapper.map_dtype(|dt| {
118 let count_dt = if *normalize {
119 DataType::Float64
120 } else {
121 IDX_DTYPE
122 };
123 DataType::Struct(vec![
124 Field::new(fields[0].name().clone(), dt.clone()),
125 Field::new(name.clone(), count_dt),
126 ])
127 }),
128 #[cfg(feature = "unique_counts")]
129 UniqueCounts => mapper.with_dtype(IDX_DTYPE),
130 Shift | Reverse => mapper.with_same_dtype(),
131 #[cfg(feature = "cum_agg")]
132 CumCount { .. } => mapper.with_dtype(IDX_DTYPE),
133 #[cfg(feature = "cum_agg")]
134 CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum),
135 #[cfg(feature = "cum_agg")]
136 CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod),
137 #[cfg(feature = "cum_agg")]
138 CumMin { .. } => mapper.with_same_dtype(),
139 #[cfg(feature = "cum_agg")]
140 CumMax { .. } => mapper.with_same_dtype(),
141 #[cfg(feature = "approx_unique")]
142 ApproxNUnique => mapper.with_dtype(IDX_DTYPE),
143 #[cfg(feature = "hist")]
144 Hist {
145 include_category,
146 include_breakpoint,
147 ..
148 } => {
149 if *include_breakpoint || *include_category {
150 let mut fields = Vec::with_capacity(3);
151 if *include_breakpoint {
152 fields.push(Field::new(
153 PlSmallStr::from_static("breakpoint"),
154 DataType::Float64,
155 ));
156 }
157 if *include_category {
158 fields.push(Field::new(
159 PlSmallStr::from_static("category"),
160 DataType::Categorical(None, Default::default()),
161 ));
162 }
163 fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE));
164 mapper.with_dtype(DataType::Struct(fields))
165 } else {
166 mapper.with_dtype(IDX_DTYPE)
167 }
168 },
169 #[cfg(feature = "diff")]
170 Diff(_) => mapper.map_dtype(|dt| match dt {
171 #[cfg(feature = "dtype-datetime")]
172 DataType::Datetime(tu, _) => DataType::Duration(*tu),
173 #[cfg(feature = "dtype-date")]
174 DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
175 #[cfg(feature = "dtype-time")]
176 DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
177 DataType::UInt64 | DataType::UInt32 => DataType::Int64,
178 DataType::UInt16 => DataType::Int32,
179 DataType::UInt8 => DataType::Int16,
180 dt => dt.clone(),
181 }),
182 #[cfg(feature = "pct_change")]
183 PctChange => mapper.map_dtype(|dt| match dt {
184 DataType::Float64 | DataType::Float32 => dt.clone(),
185 _ => DataType::Float64,
186 }),
187 #[cfg(feature = "interpolate")]
188 Interpolate(method) => match method {
189 InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(),
190 InterpolationMethod::Nearest => mapper.with_same_dtype(),
191 },
192 #[cfg(feature = "interpolate_by")]
193 InterpolateBy => mapper.map_numeric_to_float_dtype(),
194 ShrinkType => {
195 mapper.map_dtype(|dt| {
204 if dt.is_primitive_numeric() {
205 if dt.is_float() {
206 DataType::Float32
207 } else if dt.is_unsigned_integer() {
208 DataType::Int8
209 } else {
210 DataType::UInt8
211 }
212 } else {
213 dt.clone()
214 }
215 })
216 },
217 #[cfg(feature = "log")]
218 Entropy { .. } | Log { .. } | Log1p | Exp => mapper.map_to_float_dtype(),
219 Unique(_) => mapper.with_same_dtype(),
220 #[cfg(feature = "round_series")]
221 Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(),
222 UpperBound | LowerBound => mapper.with_same_dtype(),
223 #[cfg(feature = "fused")]
224 Fused(_) => mapper.map_to_supertype(),
225 ConcatExpr(_) => mapper.map_to_supertype(),
226 #[cfg(feature = "cov")]
227 Correlation { .. } => mapper.map_to_float_dtype(),
228 #[cfg(feature = "peaks")]
229 PeakMin => mapper.with_same_dtype(),
230 #[cfg(feature = "peaks")]
231 PeakMax => mapper.with_same_dtype(),
232 #[cfg(feature = "cutqcut")]
233 Cut {
234 include_breaks: false,
235 ..
236 } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
237 #[cfg(feature = "cutqcut")]
238 Cut {
239 include_breaks: true,
240 ..
241 } => {
242 let struct_dt = DataType::Struct(vec![
243 Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
244 Field::new(
245 PlSmallStr::from_static("category"),
246 DataType::Categorical(None, Default::default()),
247 ),
248 ]);
249 mapper.with_dtype(struct_dt)
250 },
251 #[cfg(feature = "repeat_by")]
252 RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
253 #[cfg(feature = "dtype-array")]
254 Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
255 let dtype = dt.inner_dtype().unwrap_or(dt).clone();
256
257 if dims.len() == 1 {
258 return Ok(dtype);
259 }
260
261 let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count();
262
263 polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
264
265 let mut inferred_size = 0;
266 if num_infers == 1 {
267 let mut total_size = 1u64;
268 let mut current = dt;
269 while let DataType::Array(dt, width) = current {
270 if *width == 0 {
271 total_size = 0;
272 break;
273 }
274
275 current = dt.as_ref();
276 total_size *= *width as u64;
277 }
278
279 let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::<u64>();
280 inferred_size = total_size / current_size;
281 }
282
283 let mut prev_dtype = dtype.leaf_dtype().clone();
284
285 for dim in &dims[1..] {
287 prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize);
288 }
289 Ok(prev_dtype)
290 }),
291 #[cfg(feature = "cutqcut")]
292 QCut {
293 include_breaks: false,
294 ..
295 } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
296 #[cfg(feature = "cutqcut")]
297 QCut {
298 include_breaks: true,
299 ..
300 } => {
301 let struct_dt = DataType::Struct(vec![
302 Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
303 Field::new(
304 PlSmallStr::from_static("category"),
305 DataType::Categorical(None, Default::default()),
306 ),
307 ]);
308 mapper.with_dtype(struct_dt)
309 },
310 #[cfg(feature = "rle")]
311 RLE => mapper.map_dtype(|dt| {
312 DataType::Struct(vec![
313 Field::new(PlSmallStr::from_static("len"), IDX_DTYPE),
314 Field::new(PlSmallStr::from_static("value"), dt.clone()),
315 ])
316 }),
317 #[cfg(feature = "rle")]
318 RLEID => mapper.with_dtype(IDX_DTYPE),
319 ToPhysical => mapper.to_physical_type(),
320 #[cfg(feature = "random")]
321 Random { .. } => mapper.with_same_dtype(),
322 SetSortedFlag(_) => mapper.with_same_dtype(),
323 #[cfg(feature = "ffi_plugin")]
324 FfiPlugin {
325 flags: _,
326 lib,
327 symbol,
328 kwargs,
329 } => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
330 MaxHorizontal => mapper.map_to_supertype(),
331 MinHorizontal => mapper.map_to_supertype(),
332 SumHorizontal { .. } => {
333 mapper.map_to_supertype().map(|mut f| {
334 if f.dtype == DataType::Boolean {
335 f.dtype = IDX_DTYPE;
336 }
337 f
338 })
339 },
340 MeanHorizontal { .. } => {
341 mapper.map_to_supertype().map(|mut f| {
342 match f.dtype {
343 dt @ DataType::Float32 => { f.dtype = dt; },
344 _ => { f.dtype = DataType::Float64; },
345 };
346 f
347 })
348 }
349 #[cfg(feature = "ewma")]
350 EwmMean { .. } => mapper.map_to_float_dtype(),
351 #[cfg(feature = "ewma_by")]
352 EwmMeanBy { .. } => mapper.map_to_float_dtype(),
353 #[cfg(feature = "ewma")]
354 EwmStd { .. } => mapper.map_to_float_dtype(),
355 #[cfg(feature = "ewma")]
356 EwmVar { .. } => mapper.map_to_float_dtype(),
357 #[cfg(feature = "replace")]
358 Replace => mapper.with_same_dtype(),
359 #[cfg(feature = "replace")]
360 ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
361 FillNullWithStrategy(_) => mapper.with_same_dtype(),
362 GatherEvery { .. } => mapper.with_same_dtype(),
363 #[cfg(feature = "reinterpret")]
364 Reinterpret(signed) => {
365 let dt = if *signed {
366 DataType::Int64
367 } else {
368 DataType::UInt64
369 };
370 mapper.with_dtype(dt)
371 },
372 ExtendConstant => mapper.with_same_dtype(),
373 }
374 }
375
376 pub(crate) fn output_name(&self) -> Option<OutputName> {
377 match self {
378 #[cfg(feature = "dtype-struct")]
379 FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => {
380 Some(OutputName::Field(name.clone()))
381 },
382 _ => None,
383 }
384 }
385}
386
387pub struct FieldsMapper<'a> {
388 fields: &'a [Field],
389}
390
391impl<'a> FieldsMapper<'a> {
392 pub fn new(fields: &'a [Field]) -> Self {
393 Self { fields }
394 }
395
396 pub fn args(&self) -> &[Field] {
397 self.fields
398 }
399
400 pub fn with_same_dtype(&self) -> PolarsResult<Field> {
402 self.map_dtype(|dtype| dtype.clone())
403 }
404
405 pub fn with_dtype(&self, dtype: DataType) -> PolarsResult<Field> {
407 Ok(Field::new(self.fields[0].name().clone(), dtype))
408 }
409
410 pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
412 let dtype = func(self.fields[0].dtype());
413 Ok(Field::new(self.fields[0].name().clone(), dtype))
414 }
415
416 pub fn get_fields_lens(&self) -> usize {
417 self.fields.len()
418 }
419
420 pub fn try_map_field(
422 &self,
423 func: impl FnOnce(&Field) -> PolarsResult<Field>,
424 ) -> PolarsResult<Field> {
425 func(&self.fields[0])
426 }
427
428 pub fn map_to_float_dtype(&self) -> PolarsResult<Field> {
430 self.map_dtype(|dtype| match dtype {
431 DataType::Float32 => DataType::Float32,
432 _ => DataType::Float64,
433 })
434 }
435
436 pub fn map_numeric_to_float_dtype(&self) -> PolarsResult<Field> {
438 self.map_dtype(|dtype| {
439 if dtype.is_primitive_numeric() {
440 match dtype {
441 DataType::Float32 => DataType::Float32,
442 _ => DataType::Float64,
443 }
444 } else {
445 dtype.clone()
446 }
447 })
448 }
449
450 pub fn to_physical_type(&self) -> PolarsResult<Field> {
452 self.map_dtype(|dtype| dtype.to_physical())
453 }
454
455 pub fn try_map_dtype(
457 &self,
458 func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
459 ) -> PolarsResult<Field> {
460 let dtype = func(self.fields[0].dtype())?;
461 Ok(Field::new(self.fields[0].name().clone(), dtype))
462 }
463
464 pub fn try_map_dtypes(
466 &self,
467 func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
468 ) -> PolarsResult<Field> {
469 let mut fld = self.fields[0].clone();
470 let dtypes = self
471 .fields
472 .iter()
473 .map(|fld| fld.dtype())
474 .collect::<Vec<_>>();
475 let new_type = func(&dtypes)?;
476 fld.coerce(new_type);
477 Ok(fld)
478 }
479
480 pub fn map_to_supertype(&self) -> PolarsResult<Field> {
482 let st = args_to_supertype(self.fields)?;
483 let mut first = self.fields[0].clone();
484 first.coerce(st);
485 Ok(first)
486 }
487
488 pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult<Field> {
490 let mut first = self.fields[0].clone();
491 let dt = first
492 .dtype()
493 .inner_dtype()
494 .cloned()
495 .unwrap_or_else(|| DataType::Unknown(Default::default()));
496 first.coerce(dt);
497 Ok(first)
498 }
499
500 #[cfg(feature = "dtype-array")]
501 pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
503 let dt = self.fields[0].dtype();
504 match dt {
505 DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
506 _ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
507 }
508 }
509
510 pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
512 self.try_map_dtypes(|dts| {
513 let mut super_type_inner = None;
514
515 for dt in dts {
516 match dt {
517 DataType::List(inner) => match super_type_inner {
518 None => super_type_inner = Some(*inner.clone()),
519 Some(st_inner) => {
520 super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
521 },
522 },
523 dt => match super_type_inner {
524 None => super_type_inner = Some((*dt).clone()),
525 Some(st_inner) => {
526 super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
527 },
528 },
529 }
530 }
531 Ok(DataType::List(Box::new(super_type_inner.unwrap())))
532 })
533 }
534
535 #[cfg(feature = "timezones")]
537 pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult<Field> {
538 self.try_map_dtype(|dt| {
539 if let DataType::Datetime(tu, _) = dt {
540 Ok(DataType::Datetime(*tu, tz.cloned()))
541 } else {
542 polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime");
543 }
544 })
545 }
546
547 pub fn sum_dtype(&self) -> PolarsResult<Field> {
548 use DataType::*;
549 self.map_dtype(|dtype| match dtype {
550 Int8 | UInt8 | Int16 | UInt16 => Int64,
551 dt => dt.clone(),
552 })
553 }
554
555 pub fn nested_sum_type(&self) -> PolarsResult<Field> {
556 let mut first = self.fields[0].clone();
557 use DataType::*;
558 let dt = first
559 .dtype()
560 .inner_dtype()
561 .cloned()
562 .unwrap_or_else(|| Unknown(Default::default()));
563
564 match dt {
565 Boolean => first.coerce(IDX_DTYPE),
566 UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64),
567 _ => first.coerce(dt),
568 }
569 Ok(first)
570 }
571
572 pub fn nested_mean_median_type(&self) -> PolarsResult<Field> {
573 let mut first = self.fields[0].clone();
574 use DataType::*;
575 let dt = first
576 .dtype()
577 .inner_dtype()
578 .cloned()
579 .unwrap_or_else(|| Unknown(Default::default()));
580
581 let new_dt = match dt {
582 #[cfg(feature = "dtype-datetime")]
583 Date => Datetime(TimeUnit::Milliseconds, None),
584 dt if dt.is_temporal() => dt,
585 Float32 => Float32,
586 _ => Float64,
587 };
588 first.coerce(new_dt);
589 Ok(first)
590 }
591
592 pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
593 let base_dtype = self.fields[0].dtype();
594 let exponent_dtype = self.fields[1].dtype();
595 if base_dtype.is_integer() {
596 if exponent_dtype.is_float() {
597 Ok(Field::new(
598 self.fields[0].name().clone(),
599 exponent_dtype.clone(),
600 ))
601 } else {
602 Ok(Field::new(
603 self.fields[0].name().clone(),
604 base_dtype.clone(),
605 ))
606 }
607 } else {
608 Ok(Field::new(
609 self.fields[0].name().clone(),
610 base_dtype.clone(),
611 ))
612 }
613 }
614
615 #[cfg(feature = "extract_jsonpath")]
616 pub fn with_opt_dtype(&self, dtype: Option<DataType>) -> PolarsResult<Field> {
617 let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default()));
618 self.with_dtype(dtype)
619 }
620
621 #[cfg(feature = "replace")]
622 pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
623 let dtype = match return_dtype {
624 Some(dtype) => dtype,
625 None => {
626 let new = &self.fields[2];
627 let default = self.fields.get(3);
628
629 let inner_dtype = new.dtype().inner_dtype().unwrap_or(new.dtype());
631
632 match default {
633 Some(default) => try_get_supertype(default.dtype(), inner_dtype)?,
634 None => inner_dtype.clone(),
635 }
636 },
637 };
638 self.with_dtype(dtype)
639 }
640}
641
642pub(crate) fn args_to_supertype<D: AsRef<DataType>>(dtypes: &[D]) -> PolarsResult<DataType> {
643 let mut st = dtypes[0].as_ref().clone();
644 for dt in &dtypes[1..] {
645 st = try_get_supertype(&st, dt.as_ref())?
646 }
647
648 match (dtypes[0].as_ref(), &st) {
649 #[cfg(feature = "dtype-categorical")]
650 (DataType::Categorical(_, ord), DataType::String) => st = DataType::Categorical(None, *ord),
651 _ => {
652 if let DataType::Unknown(kind) = st {
653 match kind {
654 UnknownKind::Float => st = DataType::Float64,
655 UnknownKind::Int(v) => {
656 st = materialize_dyn_int(v).dtype();
657 },
658 UnknownKind::Str => st = DataType::String,
659 _ => {},
660 }
661 }
662 },
663 }
664
665 Ok(st)
666}