use std::iter;
use polars_utils::itertools::Itertools;
use super::*;
impl IR {
    /// Replaces every expression held by this node, consuming and returning `self`.
    ///
    /// New expressions are assigned positionally, in the order produced by
    /// [`IR::exprs_mut`]. The two sequences are paired with `zip_eq`, which is expected
    /// to panic if `exprs` yields a different number of items than this node holds.
    pub fn with_exprs<E>(mut self, exprs: E) -> Self
    where
        E: IntoIterator<Item = ExprIR>,
    {
        for (expr, new_expr) in self.exprs_mut().zip_eq(exprs) {
            *expr = new_expr;
        }
        self
    }

    /// Replaces every input node of this node, consuming and returning `self`.
    ///
    /// New inputs are assigned positionally, in the order produced by
    /// [`IR::inputs_mut`]; pairing uses `zip_eq`, which is expected to panic on a
    /// length mismatch.
    pub fn with_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator<Item = Node>,
    {
        for (input, new_input) in self.inputs_mut().zip_eq(inputs) {
            *input = new_input;
        }
        self
    }

    /// Iterates over the expressions stored directly in this node.
    ///
    /// Ordering per variant:
    /// - `GroupBy`: keys, then aggregations.
    /// - `Join`: `left_on`, then `right_on`, then the fused cross-join filter predicate
    ///   (only when the options hold `JoinTypeOptionsIR::CrossAndFilter`).
    /// - `Scan`: the predicate if it is `Some`; `PythonScan`: the predicate only when it
    ///   is a `PythonPredicate::Polars`.
    /// - `Sink`: the partition keys of a keyed partitioned sink; nothing otherwise.
    ///
    /// # Panics
    /// Panics if called on `IR::Invalid`.
    pub fn exprs(&'_ self) -> Exprs<'_> {
        use IR::*;
        match self {
            Slice { .. }
            | Cache { .. }
            | Distinct { .. }
            | Union { .. }
            | MapFunction { .. }
            | DataFrameScan { .. }
            | HConcat { .. }
            | ExtContext { .. }
            | SimpleProjection { .. }
            | SinkMultiple { .. }
            | Gather { .. } => Exprs::Empty,
            #[cfg(feature = "merge_sorted")]
            MergeSorted { .. } => Exprs::Empty,
            #[cfg(feature = "python")]
            PythonScan { options } => match &options.predicate {
                PythonPredicate::Polars(predicate) => Exprs::single(predicate),
                _ => Exprs::Empty,
            },
            Scan { predicate, .. } => match predicate {
                Some(predicate) => Exprs::single(predicate),
                _ => Exprs::Empty,
            },
            Filter { predicate, .. } => Exprs::single(predicate),
            Sort { by_column, .. } => Exprs::slice(by_column),
            Select { expr, .. } => Exprs::slice(expr),
            HStack { exprs, .. } => Exprs::slice(exprs),
            GroupBy { keys, aggs, .. } => Exprs::double_slice(keys, aggs),
            Join {
                left_on,
                right_on,
                options,
                ..
            } => match &options.options {
                Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) => Exprs::Boxed(Box::new(
                    left_on
                        .iter()
                        .chain(right_on.iter())
                        .chain(iter::once(predicate)),
                )),
                _ => Exprs::double_slice(left_on, right_on),
            },
            Sink { payload, .. } => match payload {
                SinkTypeIR::Memory => Exprs::Empty,
                SinkTypeIR::Callback(_) => Exprs::Empty,
                SinkTypeIR::File(_) => Exprs::Empty,
                SinkTypeIR::Partitioned(PartitionedSinkOptionsIR {
                    partition_strategy, ..
                }) => match partition_strategy {
                    PartitionStrategyIR::Keyed {
                        keys,
                        include_keys: _,
                        keys_pre_grouped: _,
                    } => Exprs::Slice(keys.iter()),
                    PartitionStrategyIR::FileSize => Exprs::Empty,
                },
            },
            UnoptimizedDispatch { .. } => Exprs::Empty,
            Invalid => unreachable!(),
        }
    }

    /// Mutable counterpart of [`IR::exprs`]; yields the same expressions in the same
    /// order.
    ///
    /// For a `Join` whose options hold a `CrossAndFilter` predicate, the options are
    /// obtained through `Arc::make_mut`, which clones them if the `Arc` is shared. Joins
    /// without such a predicate leave the options `Arc` untouched.
    ///
    /// # Panics
    /// Panics if called on `IR::Invalid`.
    pub fn exprs_mut(&'_ mut self) -> ExprsMut<'_> {
        use IR::*;
        match self {
            Slice { .. }
            | Cache { .. }
            | Distinct { .. }
            | Union { .. }
            | MapFunction { .. }
            | DataFrameScan { .. }
            | HConcat { .. }
            | ExtContext { .. }
            | SimpleProjection { .. }
            | SinkMultiple { .. }
            | Gather { .. } => ExprsMut::Empty,
            #[cfg(feature = "merge_sorted")]
            MergeSorted { .. } => ExprsMut::Empty,
            #[cfg(feature = "python")]
            PythonScan { options } => match &mut options.predicate {
                PythonPredicate::Polars(predicate) => ExprsMut::single(predicate),
                _ => ExprsMut::Empty,
            },
            Scan { predicate, .. } => match predicate {
                Some(predicate) => ExprsMut::single(predicate),
                _ => ExprsMut::Empty,
            },
            Filter { predicate, .. } => ExprsMut::single(predicate),
            Sort { by_column, .. } => ExprsMut::slice(by_column),
            Select { expr, .. } => ExprsMut::slice(expr),
            HStack { exprs, .. } => ExprsMut::slice(exprs),
            GroupBy { keys, aggs, .. } => ExprsMut::double_slice(keys, aggs),
            Join {
                left_on,
                right_on,
                options,
                ..
            } => {
                // `Arc::make_mut` clones the options whenever the `Arc` is shared, so only
                // detach them when there actually is a predicate to hand out mutably.
                let has_predicate = matches!(
                    &options.options,
                    Some(JoinTypeOptionsIR::CrossAndFilter { .. })
                );
                if has_predicate {
                    match Arc::make_mut(options).options.as_mut() {
                        Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) => {
                            ExprsMut::Boxed(Box::new(
                                left_on
                                    .iter_mut()
                                    .chain(right_on.iter_mut())
                                    .chain(iter::once(predicate)),
                            ))
                        },
                        _ => unreachable!("join options were just checked to be CrossAndFilter"),
                    }
                } else {
                    ExprsMut::double_slice(left_on, right_on)
                }
            },
            Sink { payload, .. } => match payload {
                SinkTypeIR::Memory => ExprsMut::Empty,
                SinkTypeIR::Callback(_) => ExprsMut::Empty,
                SinkTypeIR::File(_) => ExprsMut::Empty,
                SinkTypeIR::Partitioned(PartitionedSinkOptionsIR {
                    partition_strategy, ..
                }) => match partition_strategy {
                    PartitionStrategyIR::Keyed {
                        keys,
                        include_keys: _,
                        keys_pre_grouped: _,
                    } => ExprsMut::Slice(keys.iter_mut()),
                    PartitionStrategyIR::FileSize => ExprsMut::Empty,
                },
            },
            UnoptimizedDispatch { .. } => ExprsMut::Empty,
            Invalid => unreachable!(),
        }
    }

    /// Clones every expression of this node (in [`IR::exprs`] order) into `container`.
    pub fn copy_exprs<T>(&self, container: &mut T)
    where
        T: Extend<ExprIR>,
    {
        container.extend(self.exprs().cloned())
    }

    /// Iterates over the input [`Node`]s of this node, by value.
    ///
    /// For `ExtContext` the main input comes first, followed by the contexts; for
    /// `Gather` the input comes before the index node; for joins, left before right.
    ///
    /// # Panics
    /// Panics if called on `IR::Invalid`.
    pub fn inputs(&self) -> Inputs<'_> {
        use IR::*;
        match self {
            Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => {
                Inputs::slice(inputs)
            },
            Slice { input, .. } => Inputs::single(*input),
            Filter { input, .. } => Inputs::single(*input),
            Select { input, .. } => Inputs::single(*input),
            SimpleProjection { input, .. } => Inputs::single(*input),
            Sort { input, .. } => Inputs::single(*input),
            Cache { input, .. } => Inputs::single(*input),
            GroupBy { input, .. } => Inputs::single(*input),
            Join {
                input_left,
                input_right,
                ..
            } => Inputs::double(*input_left, *input_right),
            Gather { input, idxs, .. } => Inputs::double(*input, *idxs),
            HStack { input, .. } => Inputs::single(*input),
            Distinct { input, .. } => Inputs::single(*input),
            MapFunction { input, .. } => Inputs::single(*input),
            Sink { input, .. } => Inputs::single(*input),
            ExtContext {
                input, contexts, ..
            } => Inputs::DoubleSlice(
                std::slice::from_ref(input)
                    .iter()
                    .chain(contexts.iter())
                    .copied(),
            ),
            Scan { .. } => Inputs::Empty,
            DataFrameScan { .. } => Inputs::Empty,
            #[cfg(feature = "python")]
            PythonScan { .. } => Inputs::Empty,
            #[cfg(feature = "merge_sorted")]
            MergeSorted {
                input_left,
                input_right,
                ..
            } => Inputs::double(*input_left, *input_right),
            UnoptimizedDispatch { inputs, .. } => Inputs::slice(inputs),
            Invalid => unreachable!(),
        }
    }

    /// Mutable counterpart of [`IR::inputs`]; yields the same nodes in the same order.
    ///
    /// # Panics
    /// Panics if called on `IR::Invalid`.
    pub fn inputs_mut(&mut self) -> InputsMut<'_> {
        use IR::*;
        match self {
            Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => {
                InputsMut::slice(inputs)
            },
            Slice { input, .. } => InputsMut::single(input),
            Filter { input, .. } => InputsMut::single(input),
            Select { input, .. } => InputsMut::single(input),
            SimpleProjection { input, .. } => InputsMut::single(input),
            Sort { input, .. } => InputsMut::single(input),
            Cache { input, .. } => InputsMut::single(input),
            GroupBy { input, .. } => InputsMut::single(input),
            Join {
                input_left,
                input_right,
                ..
            } => InputsMut::double(input_left, input_right),
            Gather { input, idxs, .. } => InputsMut::double(input, idxs),
            HStack { input, .. } => InputsMut::single(input),
            Distinct { input, .. } => InputsMut::single(input),
            MapFunction { input, .. } => InputsMut::single(input),
            Sink { input, .. } => InputsMut::single(input),
            ExtContext {
                input, contexts, ..
            } => InputsMut::DoubleSlice(std::iter::chain(
                std::slice::from_mut(input).iter_mut(),
                contexts.iter_mut(),
            )),
            Scan { .. } => InputsMut::Empty,
            DataFrameScan { .. } => InputsMut::Empty,
            #[cfg(feature = "python")]
            PythonScan { .. } => InputsMut::Empty,
            #[cfg(feature = "merge_sorted")]
            MergeSorted {
                input_left,
                input_right,
                ..
            } => InputsMut::double(input_left, input_right),
            UnoptimizedDispatch { inputs, .. } => InputsMut::slice(inputs),
            Invalid => unreachable!(),
        }
    }

    /// Appends the input nodes of this node (in [`IR::inputs`] order) to `container`.
    pub fn copy_inputs<T>(&self, container: &mut T)
    where
        T: Extend<Node>,
    {
        container.extend(self.inputs())
    }

    /// Collects the input nodes of this node into a [`UnitVec`].
    pub fn get_inputs(&self) -> UnitVec<Node> {
        self.inputs().collect()
    }

    /// Returns the first input node, or `None` when this node has no inputs.
    pub(crate) fn get_input(&self) -> Option<Node> {
        self.inputs().next()
    }
}
pub enum Inputs<'a> {
Empty,
Single(iter::Once<Node>),
Double(std::array::IntoIter<Node, 2>),
Slice(iter::Copied<std::slice::Iter<'a, Node>>),
DoubleSlice(iter::Copied<iter::Chain<std::slice::Iter<'a, Node>, std::slice::Iter<'a, Node>>>),
}
impl<'a> Inputs<'a> {
fn single(node: Node) -> Self {
Self::Single(iter::once(node))
}
fn double(left: Node, right: Node) -> Self {
Self::Double([left, right].into_iter())
}
fn slice(inputs: &'a [Node]) -> Self {
Self::Slice(inputs.iter().copied())
}
}
impl<'a> Iterator for Inputs<'a> {
    type Item = Node;

    /// Advances the wrapped iterator of the active variant.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.next(),
            Self::Double(it) => it.next(),
            Self::Slice(it) => it.next(),
            Self::DoubleSlice(it) => it.next(),
        }
    }

    /// Forwards `nth` so slice-backed variants can skip without stepping one by one.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.nth(n),
            Self::Double(it) => it.nth(n),
            Self::Slice(it) => it.nth(n),
            Self::DoubleSlice(it) => it.nth(n),
        }
    }

    /// Forwards the wrapped iterator's size hint. The default `(0, None)` would stop
    /// `collect`/`extend` from pre-allocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Empty => (0, Some(0)),
            Self::Single(it) => it.size_hint(),
            Self::Double(it) => it.size_hint(),
            Self::Slice(it) => it.size_hint(),
            Self::DoubleSlice(it) => it.size_hint(),
        }
    }
}
pub enum InputsMut<'a> {
Empty,
Single(iter::Once<&'a mut Node>),
Double(std::array::IntoIter<&'a mut Node, 2>),
Slice(std::slice::IterMut<'a, Node>),
DoubleSlice(iter::Chain<std::slice::IterMut<'a, Node>, std::slice::IterMut<'a, Node>>),
}
impl<'a> InputsMut<'a> {
fn single(node: &'a mut Node) -> Self {
Self::Single(iter::once(node))
}
fn double(left: &'a mut Node, right: &'a mut Node) -> Self {
Self::Double([left, right].into_iter())
}
fn slice(inputs: &'a mut [Node]) -> Self {
Self::Slice(inputs.iter_mut())
}
}
impl<'a> Iterator for InputsMut<'a> {
    type Item = &'a mut Node;

    /// Advances the wrapped iterator of the active variant.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.next(),
            Self::Double(it) => it.next(),
            Self::Slice(it) => it.next(),
            Self::DoubleSlice(it) => it.next(),
        }
    }

    /// Forwards `nth` so slice-backed variants can skip without stepping one by one.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.nth(n),
            Self::Double(it) => it.nth(n),
            Self::Slice(it) => it.nth(n),
            Self::DoubleSlice(it) => it.nth(n),
        }
    }

    /// Forwards the wrapped iterator's size hint instead of the default `(0, None)`,
    /// so consumers such as `zip_eq`/`collect` see the real remaining length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Empty => (0, Some(0)),
            Self::Single(it) => it.size_hint(),
            Self::Double(it) => it.size_hint(),
            Self::Slice(it) => it.size_hint(),
            Self::DoubleSlice(it) => it.size_hint(),
        }
    }
}
pub enum Exprs<'a> {
Empty,
Single(iter::Once<&'a ExprIR>),
Slice(std::slice::Iter<'a, ExprIR>),
DoubleSlice(iter::Chain<std::slice::Iter<'a, ExprIR>, std::slice::Iter<'a, ExprIR>>),
Boxed(Box<dyn Iterator<Item = &'a ExprIR> + 'a>),
}
impl<'a> Exprs<'a> {
fn single(expr: &'a ExprIR) -> Self {
Self::Single(iter::once(expr))
}
fn slice(inputs: &'a [ExprIR]) -> Self {
Self::Slice(inputs.iter())
}
fn double_slice(left: &'a [ExprIR], right: &'a [ExprIR]) -> Self {
Self::DoubleSlice(left.iter().chain(right.iter()))
}
}
impl<'a> Iterator for Exprs<'a> {
    type Item = &'a ExprIR;

    /// Advances the wrapped iterator of the active variant.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.next(),
            Self::Slice(it) => it.next(),
            Self::DoubleSlice(it) => it.next(),
            Self::Boxed(it) => it.next(),
        }
    }

    /// Forwards `nth` (as `Inputs` does) so slice-backed variants skip without
    /// stepping one element at a time.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.nth(n),
            Self::Slice(it) => it.nth(n),
            Self::DoubleSlice(it) => it.nth(n),
            Self::Boxed(it) => it.nth(n),
        }
    }

    /// Forwards the wrapped iterator's size hint, so `extend`/`collect` can
    /// pre-allocate instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Empty => (0, Some(0)),
            Self::Single(it) => it.size_hint(),
            Self::Slice(it) => it.size_hint(),
            Self::DoubleSlice(it) => it.size_hint(),
            Self::Boxed(it) => it.size_hint(),
        }
    }
}
pub enum ExprsMut<'a> {
Empty,
Single(iter::Once<&'a mut ExprIR>),
Slice(std::slice::IterMut<'a, ExprIR>),
DoubleSlice(iter::Chain<std::slice::IterMut<'a, ExprIR>, std::slice::IterMut<'a, ExprIR>>),
Boxed(Box<dyn Iterator<Item = &'a mut ExprIR> + 'a>),
}
impl<'a> ExprsMut<'a> {
fn single(expr: &'a mut ExprIR) -> Self {
Self::Single(iter::once(expr))
}
fn slice(inputs: &'a mut [ExprIR]) -> Self {
Self::Slice(inputs.iter_mut())
}
fn double_slice(left: &'a mut [ExprIR], right: &'a mut [ExprIR]) -> Self {
Self::DoubleSlice(left.iter_mut().chain(right.iter_mut()))
}
}
impl<'a> Iterator for ExprsMut<'a> {
    type Item = &'a mut ExprIR;

    /// Advances the wrapped iterator of the active variant.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.next(),
            Self::Slice(it) => it.next(),
            Self::DoubleSlice(it) => it.next(),
            Self::Boxed(it) => it.next(),
        }
    }

    /// Forwards `nth` (as `InputsMut` does) so slice-backed variants skip without
    /// stepping one element at a time.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match self {
            Self::Empty => None,
            Self::Single(it) => it.nth(n),
            Self::Slice(it) => it.nth(n),
            Self::DoubleSlice(it) => it.nth(n),
            Self::Boxed(it) => it.nth(n),
        }
    }

    /// Forwards the wrapped iterator's size hint instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Empty => (0, Some(0)),
            Self::Single(it) => it.size_hint(),
            Self::Slice(it) => it.size_hint(),
            Self::DoubleSlice(it) => it.size_hint(),
            Self::Boxed(it) => it.size_hint(),
        }
    }
}