use std::fmt::{Display, Formatter};
use polars_core::prelude::*;
use polars_ops::series::ClosedInterval;
#[cfg(feature = "temporal")]
use polars_time::{ClosedWindow, Duration};
use super::{FunctionOptions, IRFunctionExpr};
#[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))]
use crate::dsl::function_expr::DateRangeArgs;
use crate::plans::aexpr::function_expr::FieldsMapper;
use crate::prelude::FunctionFlags;
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum IRRangeFunction {
IntRange {
step: i64,
dtype: DataType,
},
IntRanges {
dtype: DataType,
},
LinearSpace {
closed: ClosedInterval,
},
LinearSpaces {
closed: ClosedInterval,
array_width: Option<usize>,
},
#[cfg(feature = "dtype-date")]
DateRange {
interval: Option<Duration>,
closed: ClosedWindow,
arg_type: DateRangeArgs,
},
#[cfg(feature = "dtype-date")]
DateRanges {
interval: Option<Duration>,
closed: ClosedWindow,
arg_type: DateRangeArgs,
},
#[cfg(feature = "dtype-datetime")]
DatetimeRange {
interval: Option<Duration>,
closed: ClosedWindow,
time_unit: Option<TimeUnit>,
time_zone: Option<TimeZone>,
arg_type: DateRangeArgs,
},
#[cfg(feature = "dtype-datetime")]
DatetimeRanges {
interval: Option<Duration>,
closed: ClosedWindow,
time_unit: Option<TimeUnit>,
time_zone: Option<TimeZone>,
arg_type: DateRangeArgs,
},
#[cfg(feature = "dtype-time")]
TimeRange {
interval: Duration,
closed: ClosedWindow,
},
#[cfg(feature = "dtype-time")]
TimeRanges {
interval: Duration,
closed: ClosedWindow,
},
}
fn map_linspace_dtype(mapper: &FieldsMapper) -> PolarsResult<DataType> {
let fields = mapper.args();
let start_dtype = fields[0].dtype();
let end_dtype = fields[1].dtype();
Ok(match (start_dtype, end_dtype) {
#[cfg(feature = "dtype-f16")]
(&DataType::Float16, &DataType::Float16) => DataType::Float16,
(&DataType::Float32, &DataType::Float32) => DataType::Float32,
(dt1, dt2) if dt1.is_temporal() && dt1 == dt2 => {
if dt1 == &DataType::Date {
DataType::Datetime(TimeUnit::Microseconds, None)
} else {
dt1.clone()
}
},
(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {
polars_bail!(ComputeError:
"'start' and 'end' have incompatible dtypes, got {:?} and {:?}",
dt1, dt2
)
},
_ => DataType::Float64,
})
}
impl IRRangeFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use IRRangeFunction::*;
match self {
IntRange { dtype, .. } => mapper.with_dtype(dtype.clone()),
IntRanges { dtype } => mapper.with_dtype(DataType::List(Box::new(dtype.clone()))),
LinearSpace { .. } => mapper.with_dtype(map_linspace_dtype(&mapper)?),
LinearSpaces {
closed: _,
array_width,
} => {
let inner = Box::new(map_linspace_dtype(&mapper)?);
let dt = match array_width {
Some(width) => DataType::Array(inner, *width),
None => DataType::List(inner),
};
mapper.with_dtype(dt)
},
#[cfg(feature = "dtype-date")]
DateRange {
interval: _,
closed: _,
arg_type: _,
} => mapper.with_dtype(DataType::Date),
#[cfg(feature = "dtype-date")]
DateRanges {
interval: _,
closed: _,
arg_type: _,
} => mapper.with_dtype(DataType::List(Box::new(DataType::Date))),
#[cfg(feature = "dtype-datetime")]
DatetimeRange {
interval: _,
closed: _,
time_unit,
time_zone,
arg_type: _,
} => {
let dtype =
mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?;
mapper.with_dtype(dtype)
},
#[cfg(feature = "dtype-datetime")]
DatetimeRanges {
interval: _,
closed: _,
time_unit,
time_zone,
arg_type: _,
} => {
let inner_dtype =
mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?;
mapper.with_dtype(DataType::List(Box::new(inner_dtype)))
},
#[cfg(feature = "dtype-time")]
TimeRange { .. } => mapper.with_dtype(DataType::Time),
#[cfg(feature = "dtype-time")]
TimeRanges { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::Time))),
}
}
pub fn function_options(&self) -> FunctionOptions {
use IRRangeFunction as R;
match self {
R::IntRange { .. } => {
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
R::IntRanges { .. } => {
FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
R::LinearSpace { .. } => {
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
R::LinearSpaces { .. } => {
FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-date")]
R::DateRange { .. } => {
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-date")]
R::DateRanges { .. } => {
FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-datetime")]
R::DatetimeRange { .. } => {
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-datetime")]
R::DatetimeRanges { .. } => {
FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-time")]
R::TimeRange { .. } => {
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
#[cfg(feature = "dtype-time")]
R::TimeRanges { .. } => {
FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
},
}
}
}
impl Display for IRRangeFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use IRRangeFunction::*;
let s = match self {
IntRange { .. } => "int_range",
IntRanges { .. } => "int_ranges",
LinearSpace { .. } => "linear_space",
LinearSpaces { .. } => "linear_spaces",
#[cfg(feature = "dtype-date")]
DateRange { .. } => "date_range",
#[cfg(feature = "dtype-date")]
DateRanges { .. } => "date_ranges",
#[cfg(feature = "dtype-datetime")]
DatetimeRange { .. } => "datetime_range",
#[cfg(feature = "dtype-datetime")]
DatetimeRanges { .. } => "datetime_ranges",
#[cfg(feature = "dtype-time")]
TimeRange { .. } => "time_range",
#[cfg(feature = "dtype-time")]
TimeRanges { .. } => "time_ranges",
};
write!(f, "{s}")
}
}
impl From<IRRangeFunction> for IRFunctionExpr {
fn from(value: IRRangeFunction) -> Self {
Self::Range(value)
}
}
impl FieldsMapper<'_> {
pub fn map_to_datetime_range_dtype(
&self,
time_unit: Option<&TimeUnit>,
time_zone: Option<&TimeZone>,
) -> PolarsResult<DataType> {
let data_dtype = self.map_to_supertype()?.dtype;
let (data_tu, data_tz) = if let DataType::Datetime(tu, tz) = data_dtype {
(tu, tz)
} else {
(TimeUnit::Microseconds, None)
};
let tu = match time_unit {
Some(tu) => *tu,
None => data_tu,
};
let tz = time_zone.cloned().or(data_tz);
Ok(DataType::Datetime(tu, tz))
}
}