use polars_core::POOL;
use polars_core::prelude::*;
use polars_plan::prelude::*;
use super::*;
use crate::expressions::{AggregationContext, PhysicalExpr};
/// Physical implementation of a ternary `when/then/otherwise` expression:
/// a boolean predicate selects, per element, between a truthy and a falsy
/// branch (combined via `zip_with`).
pub struct TernaryExpr {
    /// Boolean mask expression deciding which branch each row takes.
    predicate: Arc<dyn PhysicalExpr>,
    /// Expression producing values where the predicate is true.
    truthy: Arc<dyn PhysicalExpr>,
    /// Expression producing values where the predicate is false.
    falsy: Arc<dyn PhysicalExpr>,
    /// The originating logical expression; exposed via `as_expression`.
    expr: Expr,
    /// Whether the branches may be evaluated in parallel on the rayon pool.
    run_par: bool,
    /// Reported by `is_scalar`; whether this expression yields a scalar.
    returns_scalar: bool,
}
impl TernaryExpr {
pub fn new(
predicate: Arc<dyn PhysicalExpr>,
truthy: Arc<dyn PhysicalExpr>,
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
run_par: bool,
returns_scalar: bool,
) -> Self {
Self {
predicate,
truthy,
falsy,
expr,
run_par,
returns_scalar,
}
}
}
/// Fallback path: apply the ternary per group by zipping the group iterators
/// of the three aggregation contexts and calling `zip_with` on each group.
///
/// This is the general (slow) route, used when the three contexts' states
/// cannot be combined flat (non-unit literals, or a mix of `AggregatedScalar`
/// and `AggregatedList`).
fn finish_as_iters<'a>(
    mut ac_truthy: AggregationContext<'a>,
    mut ac_falsy: AggregationContext<'a>,
    mut ac_mask: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
    // Zip the per-group series of truthy/falsy/mask and select per group.
    let ca = ac_truthy
        .iter_groups(false)
        .zip(ac_falsy.iter_groups(false))
        .zip(ac_mask.iter_groups(false))
        .map(|((truthy, falsy), mask)| {
            match (truthy, falsy, mask) {
                // All three groups present: element-wise select via the mask.
                (Some(truthy), Some(falsy), Some(mask)) => Some(
                    truthy
                        .as_ref()
                        .zip_with(mask.as_ref().bool()?, falsy.as_ref()),
                ),
                // Any missing group becomes a null list entry.
                _ => None,
            }
            .transpose()
        })
        .collect::<PolarsResult<ListChunked>>()?
        .with_name(ac_truthy.get_values().name().clone());

    // The collect above produced a single chunk, so chunk 0 holds everything.
    let arr = ca.downcast_iter().next().unwrap();
    let list_vals_len = arr.values().len();

    let mut out = ca.into_column();
    // NOTE(review): flatten back to a plain column only when all three inputs
    // say their arity should explode AND the flattened length equals the
    // group count (i.e. one value per group) — confirm this matches the
    // `arity_should_explode` contract.
    if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() &&
        list_vals_len == ac_truthy.groups.len()
    {
        out = out.explode(ExplodeOptions {
            empty_as_null: true,
            keep_nulls: true,
        })?
    }

    // Reuse the truthy context as the carrier for the result; group offsets
    // must be recomputed from the produced series lengths.
    ac_truthy.with_agg_state(AggState::AggregatedList(out));
    ac_truthy.with_update_groups(UpdateGroups::WithSeriesLen);
    Ok(ac_truthy)
}
impl PhysicalExpr for TernaryExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Flat (non-grouped) evaluation: compute the boolean mask, evaluate both
    /// branches over the full frame (in parallel when `run_par` is set), and
    /// combine them element-wise with `zip_with`.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Work on a split copy of the state so flag changes stay local.
        let mut state = state.split();
        // NOTE(review): window-cache results are presumably not valid to
        // reuse inside a ternary branch, hence the flag is cleared — confirm
        // against `ExecutionState` semantics.
        state.remove_cache_window_flag();

        let mask_series = self.predicate.evaluate(df, &state)?;
        let mask = mask_series.bool()?.clone();

        // Both branches are evaluated eagerly; `zip_with` needs both sides.
        let op_truthy = || self.truthy.evaluate(df, &state);
        let op_falsy = || self.falsy.evaluate(df, &state);
        let (truthy, falsy) = if self.run_par {
            POOL.install(|| rayon::join(op_truthy, op_falsy))
        } else {
            (op_truthy(), op_falsy())
        };
        let truthy = truthy?;
        let falsy = falsy?;

        truthy.zip_with(&mask, &falsy)
    }

    /// Output field is taken from the truthy branch only.
    /// NOTE(review): assumes both branches were coerced to a common dtype
    /// during planning — confirm.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.truthy.to_field(input_schema)
    }

    /// Grouped evaluation. Inspects the aggregation state of the three inputs
    /// to stay on fast flat paths where possible, and falls back to the
    /// per-group iterator path (`finish_as_iters`) when the states cannot be
    /// combined directly.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Evaluate mask/truthy/falsy on the groups, in parallel when allowed.
        let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
        let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
        let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);
        let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par {
            POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)))
        } else {
            (op_mask(), (op_truthy(), op_falsy()))
        };
        let mut ac_mask = ac_mask?;
        let mut ac_truthy = ac_truthy?;
        let mut ac_falsy = ac_falsy?;

        use AggState::*;

        // Classify the three inputs:
        // - a literal of length != 1 forces the generic per-group path;
        // - any aggregated input means we must work on aggregated data;
        // - a NotAggregated input whose length was modified (original_len
        //   is false) also rules out the simple flat path.
        let mut has_non_unit_literal = false;
        let mut has_aggregated = false;
        let mut non_aggregated_len_modified = false;
        for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
            match ac.agg_state() {
                LiteralScalar(s) => {
                    has_non_unit_literal = s.len() != 1;
                    if has_non_unit_literal {
                        break;
                    }
                },
                NotAggregated(_) => {
                    non_aggregated_len_modified |= !ac.original_len;
                },
                AggregatedScalar(_) | AggregatedList(_) => {
                    has_aggregated = true;
                },
            }
        }

        if has_non_unit_literal {
            if state.verbose() {
                eprintln!("ternary agg: finish as iters due to non-unit literal")
            }
            return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
        }

        // Fast path: every input is flat (not aggregated) or a unit literal,
        // so the ternary applies element-wise on the underlying values.
        if !has_aggregated && !non_aggregated_len_modified {
            if state.verbose() {
                eprintln!("ternary agg: finish all not-aggregated or unit literal");
            }
            let out = ac_truthy
                .get_values()
                .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;

            // If any input was NotAggregated the result is too; reuse that
            // input's group metadata for the output context.
            for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
                if matches!(ac.agg_state(), NotAggregated(_)) {
                    let ac_target = ac;
                    return Ok(AggregationContext {
                        state: NotAggregated(out),
                        groups: ac_target.groups.clone(),
                        update_groups: ac_target.update_groups,
                        original_len: ac_target.original_len,
                    });
                }
            }

            // Otherwise all inputs were unit literals; the result stays one.
            ac_truthy.with_agg_state(LiteralScalar(out));
            return Ok(ac_truthy);
        }

        // At least one input is aggregated: materialize any remaining flat
        // inputs into their aggregated form so the states can line up.
        for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() {
            if matches!(ac.agg_state(), NotAggregated(_)) {
                let _ = ac.aggregated();
            }
        }

        // Unit literals broadcast; only the non-literal inputs constrain the
        // output shape.
        let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3);
        for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
            if !matches!(ac.agg_state(), LiteralScalar(_)) {
                non_literal_acs.push(ac);
            }
        }

        // A mix of AggregatedScalar and AggregatedList cannot be zipped flat;
        // fall back to the per-group path.
        for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
            if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state())
            {
                if state.verbose() {
                    eprintln!(
                        "ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList"
                    )
                }
                return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
            }
        }

        // All non-literal states now share one variant; the first serves as
        // the template for the output's groups (and, for lists, offsets).
        let ac_target = non_literal_acs.first().unwrap();
        let agg_state_out = match ac_target.agg_state() {
            AggregatedList(_) => {
                if state.verbose() {
                    eprintln!("ternary agg: finish AggregatedList")
                }
                // Lists may only be zipped on their flattened values when all
                // inputs have identical offsets (same per-group lengths).
                for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
                    match (ac_l.agg_state(), ac_r.agg_state()) {
                        (AggregatedList(s_l), AggregatedList(s_r)) => {
                            let check = s_l.list().unwrap().offsets()?.as_slice()
                                == s_r.list().unwrap().offsets()?.as_slice();
                            polars_ensure!(
                                check,
                                ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"
                            );
                        },
                        // Discriminants were checked equal above.
                        _ => unreachable!(),
                    }
                }

                // Flatten each list input to its inner values; unit literals
                // pass through unchanged.
                let truthy = if let AggregatedList(s) = ac_truthy.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_truthy.get_values().clone()
                };
                let falsy = if let AggregatedList(s) = ac_falsy.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_falsy.get_values().clone()
                };
                let mask = if let AggregatedList(s) = ac_mask.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_mask.get_values().clone()
                };

                let out = truthy.zip_with(mask.bool()?, &falsy)?;
                // Single chunk required so the values can be wrapped in one
                // list array below (array_ref(0) reads chunk 0).
                let out = out.rechunk();
                let values = out.as_materialized_series().array_ref(0);
                // Reuse the template's offsets: shapes were validated above.
                let offsets = ac_target.get_values().list().unwrap().offsets()?;
                let inner_type = out.dtype();
                let dtype = LargeListArray::default_datatype(values.dtype().clone());
                let out = LargeListArray::new(dtype, offsets, values.clone(), None);
                let mut out = ListChunked::with_chunk(truthy.name().clone(), out);
                // SAFETY: `inner_type` is the dtype of the series the values
                // were taken from, so restoring that logical dtype over the
                // same physical data is sound.
                unsafe { out.to_logical(inner_type.clone()) };

                // Propagate the fast-explode flag from the template list.
                if ac_target.get_values().list().unwrap()._can_fast_explode() {
                    out.set_fast_explode();
                };
                let out = out.into_column();
                AggregatedList(out)
            },
            AggregatedScalar(_) => {
                if state.verbose() {
                    eprintln!("ternary agg: finish AggregatedScalar")
                }
                // Non-literal inputs are per-group scalars: plain element-wise
                // zip on the underlying values.
                let out = ac_truthy
                    .get_values()
                    .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
                AggregatedScalar(out)
            },
            _ => {
                // Literal and NotAggregated states were handled/materialized
                // earlier, so only aggregated variants can reach this match.
                unreachable!()
            },
        };

        Ok(AggregationContext {
            state: agg_state_out,
            groups: ac_target.groups.clone(),
            update_groups: ac_target.update_groups,
            original_len: ac_target.original_len,
        })
    }

    fn is_scalar(&self) -> bool {
        self.returns_scalar
    }
}