use arrow::legacy::utils::CustomIterTools;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::prelude::*;
use polars_core::utils::NoNull;
use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain_col};
use super::*;
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
/// Physical expression that gathers (takes) values from `phys_expr` at the positions
/// produced by `idx`, both in plain evaluation and in group-by (aggregation) context.
pub struct GatherExpr {
    /// Expression whose values are gathered from.
    pub(crate) phys_expr: Arc<dyn PhysicalExpr>,
    /// Expression producing the gather indices (integer, or list of integers in group context).
    pub(crate) idx: Arc<dyn PhysicalExpr>,
    /// The logical expression this node was built from; exposed via `as_expression` and used
    /// as context for out-of-bounds errors and for `with_values`.
    pub(crate) expr: Expr,
    /// When true the gather yields a single value per group (reported by `is_scalar`); list
    /// indices are rejected with "expected single index" and scalar results are not wrapped
    /// in a list.
    pub(crate) returns_scalar: bool,
}
impl PhysicalExpr for GatherExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Evaluates the source expression on the whole frame, then gathers from it with the
    /// index expression (see `GatherExpr::finish`).
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        self.finish(df, state, self.phys_expr.evaluate(df, state)?)
    }

    /// Group-wise gather.
    ///
    /// The source and index expressions are both evaluated per group. The index must be an
    /// integer column or a list of integers. Integer indices that might be negative go through
    /// `process_negative_indices_agg`. Scalar and literal indices use dedicated paths. List
    /// indices fall through to a per-group amortized gather whose groups are re-derived from
    /// the resulting list lengths.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut src_ac = self.phys_expr.evaluate_on_groups(df, groups, state)?;
        let mut idx_ac = self.idx.evaluate_on_groups(df, groups, state)?;

        // Validate the index dtype; possibly-negative integer indices take the slow path.
        let idx_values = idx_ac.get_values();
        match idx_values.dtype() {
            DataType::List(inner) => {
                polars_ensure!(inner.is_integer(), InvalidOperation: "expected numeric dtype as index, got {:?}", inner)
            },
            dt if dt.is_integer() => {
                if !is_positive_idx_uncertain_col(idx_values) {
                    return self.process_negative_indices_agg(src_ac, idx_ac, groups);
                }
            },
            dt => polars_bail!(InvalidOperation: "expected numeric dtype as index, got {:?}", dt),
        }

        // Reduce the index to one list of positions per group, short-circuiting on scalar
        // and literal indices which have cheaper dedicated paths.
        let idx_list = match idx_ac.state {
            AggState::AggregatedScalar(per_group) => {
                let per_group = per_group.cast(&IDX_DTYPE)?;
                return self.process_positive_indices_agg_scalar(src_ac, per_group.idx().unwrap());
            },
            AggState::AggregatedList(lists) => {
                polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
                lists.list().unwrap().clone()
            },
            AggState::NotAggregated(_) => {
                polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
                let lists = idx_ac.aggregated();
                lists.list().unwrap().clone()
            },
            AggState::Literal(lit) => {
                let lit = lit.cast(&IDX_DTYPE)?;
                return self.process_positive_indices_agg_literal(src_ac, lit.idx().unwrap());
            },
        };

        let idx_series = idx_list.cast(&DataType::List(Box::new(IDX_DTYPE)))?;
        let idx_lists = idx_series.list().unwrap();

        // Pair each group's values with that group's index list; a null on either side
        // yields a null output list.
        let gathered = {
            let src_lists = src_ac.aggregated();
            let out: ListChunked = src_lists
                .list()
                .unwrap()
                .amortized_iter()
                .zip(idx_lists.amortized_iter())
                .map(|(values, indices)| match (values, indices) {
                    (Some(values), Some(indices)) => values
                        .as_ref()
                        .take(indices.as_ref().idx().unwrap())
                        .map(Some),
                    _ => Ok(None),
                })
                .collect::<PolarsResult<ListChunked>>()?;
            out.with_name(src_ac.get_values().name().clone())
        };

        src_ac.with_values(gathered.into_column(), true, Some(&self.expr))?;
        src_ac.with_update_groups(UpdateGroups::WithSeriesLen);
        Ok(src_ac)
    }

    /// Both the source and the index expression may reference input columns.
    fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>) {
        for child in [&self.phys_expr, &self.idx] {
            child.collect_live_columns(lv);
        }
    }

    /// The output field is the field of the gathered-from expression.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        let source = &self.phys_expr;
        source.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        self.returns_scalar
    }
}
impl GatherExpr {
fn finish(
&self,
df: &DataFrame,
state: &ExecutionState,
series: Column,
) -> PolarsResult<Column> {
let idx = self.idx.evaluate(df, state)?;
let idx = convert_to_unsigned_index(idx.as_materialized_series(), series.len())?;
series.take(&idx)
}
fn oob_err(&self) -> PolarsResult<()> {
polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds");
}
fn process_positive_indices_agg_scalar<'b>(
&self,
mut ac: AggregationContext<'b>,
idx: &IdxCa,
) -> PolarsResult<AggregationContext<'b>> {
if ac.is_not_aggregated() {
let groups = ac.groups();
let idx: IdxCa = match groups.as_ref().as_ref() {
GroupsType::Idx(groups) => {
if groups.all().iter().zip(idx).any(|(g, idx)| match idx {
None => false,
Some(idx) => idx >= g.len() as IdxSize,
}) {
self.oob_err()?;
}
idx.into_iter()
.zip(groups.iter())
.map(|(idx, (_first, groups))| {
idx.map(|idx| {
unsafe { *groups.get_unchecked(usize::try_from(idx).unwrap()) }
})
})
.collect_trusted()
},
GroupsType::Slice { groups, .. } => {
if groups.iter().zip(idx).any(|(g, idx)| match idx {
None => false,
Some(idx) => idx >= g[1],
}) {
self.oob_err()?;
}
idx.into_iter()
.zip(groups.iter())
.map(|(idx, g)| idx.map(|idx| idx + g[0]))
.collect_trusted()
},
};
let taken = ac.flat_naive().take(&idx)?;
let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_column()
};
ac.with_values(taken, true, Some(&self.expr))?;
Ok(ac)
} else {
self.gather_aggregated_expensive(ac, idx)
}
}
fn gather_aggregated_expensive<'b>(
&self,
mut ac: AggregationContext<'b>,
idx: &IdxCa,
) -> PolarsResult<AggregationContext<'b>> {
let out = ac
.aggregated()
.list()
.unwrap()
.try_apply_amortized(|s| s.as_ref().take(idx))?;
ac.with_values(out.into_column(), true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
}
fn process_positive_indices_agg_literal<'b>(
&self,
mut ac: AggregationContext<'b>,
idx: &IdxCa,
) -> PolarsResult<AggregationContext<'b>> {
if idx.len() == 1 {
match idx.get(0) {
None => polars_bail!(ComputeError: "cannot take by a null"),
Some(idx) => {
let groups = ac.groups();
let idx: NoNull<IdxCa> = match groups.as_ref().as_ref() {
GroupsType::Idx(groups) => {
if groups.all().iter().any(|g| idx >= g.len() as IdxSize) {
self.oob_err()?;
}
groups
.iter()
.map(|(_, group)| {
unsafe { *group.get_unchecked(idx as usize) }
})
.collect_trusted()
},
GroupsType::Slice { groups, .. } => {
if groups.iter().any(|g| idx >= g[1]) {
self.oob_err()?;
}
groups.iter().map(|g| g[0] + idx).collect_trusted()
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;
let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_column()
};
ac.with_values(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
},
}
} else {
self.gather_aggregated_expensive(ac, idx)
}
}
fn process_negative_indices_agg<'b>(
&self,
mut ac: AggregationContext<'b>,
mut idx: AggregationContext<'b>,
groups: &'b GroupsType,
) -> PolarsResult<AggregationContext<'b>> {
let mut builder = get_list_builder(
&ac.dtype(),
idx.get_values().len(),
groups.len(),
ac.get_values().name().clone(),
);
let iter = ac.iter_groups(false).zip(idx.iter_groups(false));
for (s, idx) in iter {
match (s, idx) {
(Some(s), Some(idx)) => {
let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?;
let out = s.as_ref().take(&idx)?;
builder.append_series(&out)?;
},
_ => builder.append_null(),
};
}
let out = builder.finish().into_column();
ac.with_agg_state(AggState::AggregatedList(out));
Ok(ac)
}
}