1use polars_core::utils::materialize_dyn_int;
2
3use super::*;
4
5impl FunctionExpr {
6 pub(crate) fn get_field(
7 &self,
8 _input_schema: &Schema,
9 _cntxt: Context,
10 fields: &[Field],
11 ) -> PolarsResult<Field> {
12 use FunctionExpr::*;
13
14 let mapper = FieldsMapper { fields };
15 match self {
16 #[cfg(feature = "dtype-array")]
18 ArrayExpr(func) => func.get_field(mapper),
19 BinaryExpr(s) => s.get_field(mapper),
20 #[cfg(feature = "dtype-categorical")]
21 Categorical(func) => func.get_field(mapper),
22 ListExpr(func) => func.get_field(mapper),
23 #[cfg(feature = "strings")]
24 StringExpr(s) => s.get_field(mapper),
25 #[cfg(feature = "dtype-struct")]
26 StructExpr(s) => s.get_field(mapper),
27 #[cfg(feature = "temporal")]
28 TemporalExpr(fun) => fun.get_field(mapper),
29 #[cfg(feature = "bitwise")]
30 Bitwise(fun) => fun.get_field(mapper),
31
32 Boolean(func) => func.get_field(mapper),
34 #[cfg(feature = "business")]
35 Business(func) => match func {
36 BusinessFunction::BusinessDayCount { .. } => mapper.with_dtype(DataType::Int32),
37 BusinessFunction::AddBusinessDay { .. } => mapper.with_same_dtype(),
38 },
39 #[cfg(feature = "abs")]
40 Abs => mapper.with_same_dtype(),
41 Negate => mapper.with_same_dtype(),
42 NullCount => mapper.with_dtype(IDX_DTYPE),
43 Pow(pow_function) => match pow_function {
44 PowFunction::Generic => mapper.pow_dtype(),
45 _ => mapper.map_to_float_dtype(),
46 },
47 Coalesce => mapper.map_to_supertype(),
48 #[cfg(feature = "row_hash")]
49 Hash(..) => mapper.with_dtype(DataType::UInt64),
50 #[cfg(feature = "arg_where")]
51 ArgWhere => mapper.with_dtype(IDX_DTYPE),
52 #[cfg(feature = "index_of")]
53 IndexOf => mapper.with_dtype(IDX_DTYPE),
54 #[cfg(feature = "search_sorted")]
55 SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
56 #[cfg(feature = "range")]
57 Range(func) => func.get_field(mapper),
58 #[cfg(feature = "trigonometry")]
59 Trigonometry(_) => mapper.map_to_float_dtype(),
60 #[cfg(feature = "trigonometry")]
61 Atan2 => mapper.map_to_float_dtype(),
62 #[cfg(feature = "sign")]
63 Sign => mapper.with_dtype(DataType::Int64),
64 FillNull { .. } => mapper.map_to_supertype(),
65 #[cfg(feature = "rolling_window")]
66 RollingExpr(rolling_func, ..) => {
67 use RollingFunction::*;
68 match rolling_func {
69 Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(),
70 Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(),
71 #[cfg(feature = "cov")]
72 CorrCov {..} => mapper.map_to_float_dtype(),
73 #[cfg(feature = "moment")]
74 Skew(..) => mapper.map_to_float_dtype(),
75 }
76 },
77 #[cfg(feature = "rolling_window_by")]
78 RollingExprBy(rolling_func, ..) => {
79 use RollingFunctionBy::*;
80 match rolling_func {
81 MinBy(_) | MaxBy(_) | SumBy(_) => mapper.with_same_dtype(),
82 MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(),
83 }
84 },
85 ShiftAndFill => mapper.with_same_dtype(),
86 DropNans => mapper.with_same_dtype(),
87 DropNulls => mapper.with_same_dtype(),
88 #[cfg(feature = "round_series")]
89 Clip { .. } => mapper.with_same_dtype(),
90 #[cfg(feature = "mode")]
91 Mode => mapper.with_same_dtype(),
92 #[cfg(feature = "moment")]
93 Skew(_) => mapper.with_dtype(DataType::Float64),
94 #[cfg(feature = "moment")]
95 Kurtosis(..) => mapper.with_dtype(DataType::Float64),
96 ArgUnique => mapper.with_dtype(IDX_DTYPE),
97 Repeat => mapper.with_same_dtype(),
98 #[cfg(feature = "rank")]
99 Rank { options, .. } => mapper.with_dtype(match options.method {
100 RankMethod::Average => DataType::Float64,
101 _ => IDX_DTYPE,
102 }),
103 #[cfg(feature = "dtype-struct")]
104 AsStruct => Ok(Field::new(
105 fields[0].name().clone(),
106 DataType::Struct(fields.to_vec()),
107 )),
108 #[cfg(feature = "top_k")]
109 TopK { .. } => mapper.with_same_dtype(),
110 #[cfg(feature = "top_k")]
111 TopKBy { .. } => mapper.with_same_dtype(),
112 #[cfg(feature = "dtype-struct")]
113 ValueCounts {
114 sort: _,
115 parallel: _,
116 name,
117 normalize,
118 } => mapper.map_dtype(|dt| {
119 let count_dt = if *normalize {
120 DataType::Float64
121 } else {
122 IDX_DTYPE
123 };
124 DataType::Struct(vec![
125 Field::new(fields[0].name().clone(), dt.clone()),
126 Field::new(name.clone(), count_dt),
127 ])
128 }),
129 #[cfg(feature = "unique_counts")]
130 UniqueCounts => mapper.with_dtype(IDX_DTYPE),
131 Shift | Reverse => mapper.with_same_dtype(),
132 #[cfg(feature = "cum_agg")]
133 CumCount { .. } => mapper.with_dtype(IDX_DTYPE),
134 #[cfg(feature = "cum_agg")]
135 CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum),
136 #[cfg(feature = "cum_agg")]
137 CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod),
138 #[cfg(feature = "cum_agg")]
139 CumMin { .. } => mapper.with_same_dtype(),
140 #[cfg(feature = "cum_agg")]
141 CumMax { .. } => mapper.with_same_dtype(),
142 #[cfg(feature = "approx_unique")]
143 ApproxNUnique => mapper.with_dtype(IDX_DTYPE),
144 #[cfg(feature = "hist")]
145 Hist {
146 include_category,
147 include_breakpoint,
148 ..
149 } => {
150 if *include_breakpoint || *include_category {
151 let mut fields = Vec::with_capacity(3);
152 if *include_breakpoint {
153 fields.push(Field::new(
154 PlSmallStr::from_static("breakpoint"),
155 DataType::Float64,
156 ));
157 }
158 if *include_category {
159 fields.push(Field::new(
160 PlSmallStr::from_static("category"),
161 DataType::Categorical(None, Default::default()),
162 ));
163 }
164 fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE));
165 mapper.with_dtype(DataType::Struct(fields))
166 } else {
167 mapper.with_dtype(IDX_DTYPE)
168 }
169 },
170 #[cfg(feature = "diff")]
171 Diff(_, _) => mapper.map_dtype(|dt| match dt {
172 #[cfg(feature = "dtype-datetime")]
173 DataType::Datetime(tu, _) => DataType::Duration(*tu),
174 #[cfg(feature = "dtype-date")]
175 DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
176 #[cfg(feature = "dtype-time")]
177 DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
178 DataType::UInt64 | DataType::UInt32 => DataType::Int64,
179 DataType::UInt16 => DataType::Int32,
180 DataType::UInt8 => DataType::Int16,
181 dt => dt.clone(),
182 }),
183 #[cfg(feature = "pct_change")]
184 PctChange => mapper.map_dtype(|dt| match dt {
185 DataType::Float64 | DataType::Float32 => dt.clone(),
186 _ => DataType::Float64,
187 }),
188 #[cfg(feature = "interpolate")]
189 Interpolate(method) => match method {
190 InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(),
191 InterpolationMethod::Nearest => mapper.with_same_dtype(),
192 },
193 #[cfg(feature = "interpolate_by")]
194 InterpolateBy => mapper.map_numeric_to_float_dtype(),
195 ShrinkType => {
196 mapper.map_dtype(|dt| {
205 if dt.is_primitive_numeric() {
206 if dt.is_float() {
207 DataType::Float32
208 } else if dt.is_unsigned_integer() {
209 DataType::Int8
210 } else {
211 DataType::UInt8
212 }
213 } else {
214 dt.clone()
215 }
216 })
217 },
218 #[cfg(feature = "log")]
219 Entropy { .. } | Log { .. } | Log1p | Exp => mapper.map_to_float_dtype(),
220 Unique(_) => mapper.with_same_dtype(),
221 #[cfg(feature = "round_series")]
222 Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(),
223 UpperBound | LowerBound => mapper.with_same_dtype(),
224 #[cfg(feature = "fused")]
225 Fused(_) => mapper.map_to_supertype(),
226 ConcatExpr(_) => mapper.map_to_supertype(),
227 #[cfg(feature = "cov")]
228 Correlation { .. } => mapper.map_to_float_dtype(),
229 #[cfg(feature = "peaks")]
230 PeakMin => mapper.with_same_dtype(),
231 #[cfg(feature = "peaks")]
232 PeakMax => mapper.with_same_dtype(),
233 #[cfg(feature = "cutqcut")]
234 Cut {
235 include_breaks: false,
236 ..
237 } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
238 #[cfg(feature = "cutqcut")]
239 Cut {
240 include_breaks: true,
241 ..
242 } => {
243 let struct_dt = DataType::Struct(vec![
244 Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
245 Field::new(
246 PlSmallStr::from_static("category"),
247 DataType::Categorical(None, Default::default()),
248 ),
249 ]);
250 mapper.with_dtype(struct_dt)
251 },
252 #[cfg(feature = "repeat_by")]
253 RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
254 #[cfg(feature = "dtype-array")]
255 Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
256 let dtype = dt.inner_dtype().unwrap_or(dt).clone();
257
258 if dims.len() == 1 {
259 return Ok(dtype);
260 }
261
262 let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count();
263
264 polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
265
266 let mut inferred_size = 0;
267 if num_infers == 1 {
268 let mut total_size = 1u64;
269 let mut current = dt;
270 while let DataType::Array(dt, width) = current {
271 if *width == 0 {
272 total_size = 0;
273 break;
274 }
275
276 current = dt.as_ref();
277 total_size *= *width as u64;
278 }
279
280 let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::<u64>();
281 inferred_size = total_size / current_size;
282 }
283
284 let mut prev_dtype = dtype.leaf_dtype().clone();
285
286 for dim in &dims[1..] {
288 prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize);
289 }
290 Ok(prev_dtype)
291 }),
292 #[cfg(feature = "cutqcut")]
293 QCut {
294 include_breaks: false,
295 ..
296 } => mapper.with_dtype(DataType::Categorical(None, Default::default())),
297 #[cfg(feature = "cutqcut")]
298 QCut {
299 include_breaks: true,
300 ..
301 } => {
302 let struct_dt = DataType::Struct(vec![
303 Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
304 Field::new(
305 PlSmallStr::from_static("category"),
306 DataType::Categorical(None, Default::default()),
307 ),
308 ]);
309 mapper.with_dtype(struct_dt)
310 },
311 #[cfg(feature = "rle")]
312 RLE => mapper.map_dtype(|dt| {
313 DataType::Struct(vec![
314 Field::new(PlSmallStr::from_static("len"), IDX_DTYPE),
315 Field::new(PlSmallStr::from_static("value"), dt.clone()),
316 ])
317 }),
318 #[cfg(feature = "rle")]
319 RLEID => mapper.with_dtype(IDX_DTYPE),
320 ToPhysical => mapper.to_physical_type(),
321 #[cfg(feature = "random")]
322 Random { .. } => mapper.with_same_dtype(),
323 SetSortedFlag(_) => mapper.with_same_dtype(),
324 #[cfg(feature = "ffi_plugin")]
325 FfiPlugin {
326 lib,
327 symbol,
328 kwargs,
329 } => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
330 BackwardFill { .. } => mapper.with_same_dtype(),
331 ForwardFill { .. } => mapper.with_same_dtype(),
332 MaxHorizontal => mapper.map_to_supertype(),
333 MinHorizontal => mapper.map_to_supertype(),
334 SumHorizontal { .. } => {
335 mapper.map_to_supertype().map(|mut f| {
336 if f.dtype == DataType::Boolean {
337 f.dtype = IDX_DTYPE;
338 }
339 f
340 })
341 },
342 MeanHorizontal { .. } => {
343 mapper.map_to_supertype().map(|mut f| {
344 match f.dtype {
345 dt @ DataType::Float32 => { f.dtype = dt; },
346 _ => { f.dtype = DataType::Float64; },
347 };
348 f
349 })
350 }
351 #[cfg(feature = "ewma")]
352 EwmMean { .. } => mapper.map_to_float_dtype(),
353 #[cfg(feature = "ewma_by")]
354 EwmMeanBy { .. } => mapper.map_to_float_dtype(),
355 #[cfg(feature = "ewma")]
356 EwmStd { .. } => mapper.map_to_float_dtype(),
357 #[cfg(feature = "ewma")]
358 EwmVar { .. } => mapper.map_to_float_dtype(),
359 #[cfg(feature = "replace")]
360 Replace => mapper.with_same_dtype(),
361 #[cfg(feature = "replace")]
362 ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
363 FillNullWithStrategy(_) => mapper.with_same_dtype(),
364 GatherEvery { .. } => mapper.with_same_dtype(),
365 #[cfg(feature = "reinterpret")]
366 Reinterpret(signed) => {
367 let dt = if *signed {
368 DataType::Int64
369 } else {
370 DataType::UInt64
371 };
372 mapper.with_dtype(dt)
373 },
374 ExtendConstant => mapper.with_same_dtype(),
375 }
376 }
377
378 pub(crate) fn output_name(&self) -> Option<OutputName> {
379 match self {
380 #[cfg(feature = "dtype-struct")]
381 FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => {
382 Some(OutputName::Field(name.clone()))
383 },
384 _ => None,
385 }
386 }
387}
388
/// Helper that derives an output [`Field`] from the input fields of a
/// function expression.
pub struct FieldsMapper<'a> {
    /// Borrowed input fields; the mapping methods index `fields[0]`, so the
    /// slice must not be empty when they are called.
    fields: &'a [Field],
}
392
393impl<'a> FieldsMapper<'a> {
394 pub fn new(fields: &'a [Field]) -> Self {
395 Self { fields }
396 }
397
398 pub fn args(&self) -> &[Field] {
399 self.fields
400 }
401
402 pub fn with_same_dtype(&self) -> PolarsResult<Field> {
404 self.map_dtype(|dtype| dtype.clone())
405 }
406
407 pub fn with_dtype(&self, dtype: DataType) -> PolarsResult<Field> {
409 Ok(Field::new(self.fields[0].name().clone(), dtype))
410 }
411
412 pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
414 let dtype = func(self.fields[0].dtype());
415 Ok(Field::new(self.fields[0].name().clone(), dtype))
416 }
417
418 pub fn get_fields_lens(&self) -> usize {
419 self.fields.len()
420 }
421
422 pub fn try_map_field(
424 &self,
425 func: impl FnOnce(&Field) -> PolarsResult<Field>,
426 ) -> PolarsResult<Field> {
427 func(&self.fields[0])
428 }
429
430 pub fn map_to_float_dtype(&self) -> PolarsResult<Field> {
432 self.map_dtype(|dtype| match dtype {
433 DataType::Float32 => DataType::Float32,
434 _ => DataType::Float64,
435 })
436 }
437
438 pub fn map_numeric_to_float_dtype(&self) -> PolarsResult<Field> {
440 self.map_dtype(|dtype| {
441 if dtype.is_primitive_numeric() {
442 match dtype {
443 DataType::Float32 => DataType::Float32,
444 _ => DataType::Float64,
445 }
446 } else {
447 dtype.clone()
448 }
449 })
450 }
451
452 pub fn to_physical_type(&self) -> PolarsResult<Field> {
454 self.map_dtype(|dtype| dtype.to_physical())
455 }
456
457 pub fn try_map_dtype(
459 &self,
460 func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
461 ) -> PolarsResult<Field> {
462 let dtype = func(self.fields[0].dtype())?;
463 Ok(Field::new(self.fields[0].name().clone(), dtype))
464 }
465
466 pub fn try_map_dtypes(
468 &self,
469 func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
470 ) -> PolarsResult<Field> {
471 let mut fld = self.fields[0].clone();
472 let dtypes = self
473 .fields
474 .iter()
475 .map(|fld| fld.dtype())
476 .collect::<Vec<_>>();
477 let new_type = func(&dtypes)?;
478 fld.coerce(new_type);
479 Ok(fld)
480 }
481
482 pub fn map_to_supertype(&self) -> PolarsResult<Field> {
484 let st = args_to_supertype(self.fields)?;
485 let mut first = self.fields[0].clone();
486 first.coerce(st);
487 Ok(first)
488 }
489
490 pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult<Field> {
492 let mut first = self.fields[0].clone();
493 let dt = first
494 .dtype()
495 .inner_dtype()
496 .cloned()
497 .unwrap_or_else(|| DataType::Unknown(Default::default()));
498 first.coerce(dt);
499 Ok(first)
500 }
501
502 #[cfg(feature = "dtype-array")]
503 pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
505 let dt = self.fields[0].dtype();
506 match dt {
507 DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
508 _ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
509 }
510 }
511
512 pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
514 self.try_map_dtypes(|dts| {
515 let mut super_type_inner = None;
516
517 for dt in dts {
518 match dt {
519 DataType::List(inner) => match super_type_inner {
520 None => super_type_inner = Some(*inner.clone()),
521 Some(st_inner) => {
522 super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
523 },
524 },
525 dt => match super_type_inner {
526 None => super_type_inner = Some((*dt).clone()),
527 Some(st_inner) => {
528 super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
529 },
530 },
531 }
532 }
533 Ok(DataType::List(Box::new(super_type_inner.unwrap())))
534 })
535 }
536
537 #[cfg(feature = "timezones")]
539 pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult<Field> {
540 self.try_map_dtype(|dt| {
541 if let DataType::Datetime(tu, _) = dt {
542 Ok(DataType::Datetime(*tu, tz.cloned()))
543 } else {
544 polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime");
545 }
546 })
547 }
548
549 pub fn nested_sum_type(&self) -> PolarsResult<Field> {
550 let mut first = self.fields[0].clone();
551 use DataType::*;
552 let dt = first
553 .dtype()
554 .inner_dtype()
555 .cloned()
556 .unwrap_or_else(|| Unknown(Default::default()));
557
558 match dt {
559 Boolean => first.coerce(IDX_DTYPE),
560 UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64),
561 _ => first.coerce(dt),
562 }
563 Ok(first)
564 }
565
566 pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
567 let base_dtype = self.fields[0].dtype();
568 let exponent_dtype = self.fields[1].dtype();
569 if base_dtype.is_integer() {
570 if exponent_dtype.is_float() {
571 Ok(Field::new(
572 self.fields[0].name().clone(),
573 exponent_dtype.clone(),
574 ))
575 } else {
576 Ok(Field::new(
577 self.fields[0].name().clone(),
578 base_dtype.clone(),
579 ))
580 }
581 } else {
582 Ok(Field::new(
583 self.fields[0].name().clone(),
584 base_dtype.clone(),
585 ))
586 }
587 }
588
589 #[cfg(feature = "extract_jsonpath")]
590 pub fn with_opt_dtype(&self, dtype: Option<DataType>) -> PolarsResult<Field> {
591 let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default()));
592 self.with_dtype(dtype)
593 }
594
595 #[cfg(feature = "replace")]
596 pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
597 let dtype = match return_dtype {
598 Some(dtype) => dtype,
599 None => {
600 let new = &self.fields[2];
601 let default = self.fields.get(3);
602 match default {
603 Some(default) => try_get_supertype(default.dtype(), new.dtype())?,
604 None => new.dtype().clone(),
605 }
606 },
607 };
608 self.with_dtype(dtype)
609 }
610}
611
612pub(crate) fn args_to_supertype<D: AsRef<DataType>>(dtypes: &[D]) -> PolarsResult<DataType> {
613 let mut st = dtypes[0].as_ref().clone();
614 for dt in &dtypes[1..] {
615 st = try_get_supertype(&st, dt.as_ref())?
616 }
617
618 match (dtypes[0].as_ref(), &st) {
619 #[cfg(feature = "dtype-categorical")]
620 (DataType::Categorical(_, ord), DataType::String) => st = DataType::Categorical(None, *ord),
621 _ => {
622 if let DataType::Unknown(kind) = st {
623 match kind {
624 UnknownKind::Float => st = DataType::Float64,
625 UnknownKind::Int(v) => {
626 st = materialize_dyn_int(v).dtype();
627 },
628 UnknownKind::Str => st = DataType::String,
629 _ => {},
630 }
631 }
632 },
633 }
634
635 Ok(st)
636}