use std::{ops::Neg, sync::Arc};
use crate::expressions::Column;
use crate::utils::get_indices_of_matching_sort_exprs_with_order_eq;
use crate::{
EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr,
};
use arrow_schema::SortOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::Result;
use itertools::Itertools;
/// Ordering information known about the output of an expression.
///
/// The combinators on this type treat the variants as follows:
/// - `Ordered(options)`: the output is sorted according to `options`.
/// - `Unordered`: no ordering is known.
/// - `Singleton`: when combined with another operand (e.g. in `add`), this side is
///   absorbed and the other operand's property is kept, so it presumably denotes a
///   constant / single-valued expression such as a literal — confirm at call sites.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SortProperties {
    /// Sorted according to the contained options.
    Ordered(SortOptions),
    /// No usable ordering.
    Unordered,
    /// Compatible with any ordering (see the type-level note).
    Singleton,
}
impl SortProperties {
pub fn add(&self, rhs: &Self) -> Self {
match (self, rhs) {
(Self::Singleton, _) => *rhs,
(_, Self::Singleton) => *self,
(Self::Ordered(lhs), Self::Ordered(rhs))
if lhs.descending == rhs.descending =>
{
Self::Ordered(SortOptions {
descending: lhs.descending,
nulls_first: lhs.nulls_first || rhs.nulls_first,
})
}
_ => Self::Unordered,
}
}
pub fn sub(&self, rhs: &Self) -> Self {
match (self, rhs) {
(Self::Singleton, Self::Singleton) => Self::Singleton,
(Self::Singleton, Self::Ordered(rhs)) => Self::Ordered(SortOptions {
descending: !rhs.descending,
nulls_first: rhs.nulls_first,
}),
(_, Self::Singleton) => *self,
(Self::Ordered(lhs), Self::Ordered(rhs))
if lhs.descending != rhs.descending =>
{
Self::Ordered(SortOptions {
descending: lhs.descending,
nulls_first: lhs.nulls_first || rhs.nulls_first,
})
}
_ => Self::Unordered,
}
}
pub fn gt_or_gteq(&self, rhs: &Self) -> Self {
match (self, rhs) {
(Self::Singleton, Self::Ordered(rhs)) => Self::Ordered(SortOptions {
descending: !rhs.descending,
nulls_first: rhs.nulls_first,
}),
(_, Self::Singleton) => *self,
(Self::Ordered(lhs), Self::Ordered(rhs))
if lhs.descending != rhs.descending =>
{
*self
}
_ => Self::Unordered,
}
}
pub fn and(&self, rhs: &Self) -> Self {
match (self, rhs) {
(Self::Ordered(lhs), Self::Ordered(rhs))
if lhs.descending == rhs.descending =>
{
Self::Ordered(SortOptions {
descending: lhs.descending,
nulls_first: lhs.nulls_first || rhs.nulls_first,
})
}
(Self::Ordered(opt), Self::Singleton)
| (Self::Singleton, Self::Ordered(opt)) => Self::Ordered(SortOptions {
descending: opt.descending,
nulls_first: opt.nulls_first,
}),
(Self::Singleton, Self::Singleton) => Self::Singleton,
_ => Self::Unordered,
}
}
}
impl Neg for SortProperties {
    type Output = Self;

    /// Reverses the direction of an `Ordered` property while keeping its NULL
    /// placement; `Singleton` and `Unordered` are returned unchanged.
    fn neg(self) -> Self::Output {
        match self {
            Self::Ordered(opts) => Self::Ordered(SortOptions {
                descending: !opts.descending,
                ..opts
            }),
            Self::Singleton | Self::Unordered => self,
        }
    }
}
/// A node pairing a physical expression with the ordering information computed
/// for it.
///
/// `map_children` transforms the children first and then rebuilds the parent from
/// the children's states, so ordering information flows from the leaves upwards.
#[derive(Debug)]
pub struct ExprOrdering {
    /// The expression this node wraps.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Ordering computed for `expr`; `None` until `update_ordering` sets it, and it
    /// can also be left `None` when a column's ordering could not be matched.
    pub state: Option<SortProperties>,
    /// States of the children, set by `new_with_children`; `None` for nodes built
    /// with `new`.
    pub children_states: Option<Vec<SortProperties>>,
}
impl ExprOrdering {
    /// Wraps `expr` in a node that carries no ordering information yet.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self {
            expr,
            state: None,
            children_states: None,
        }
    }

    /// Builds a fresh, state-less node for every child of `self.expr`.
    pub fn children(&self) -> Vec<ExprOrdering> {
        let child_exprs = self.expr.children();
        child_exprs.into_iter().map(Self::new).collect()
    }

    /// Builds a node for `parent_expr` whose children have already been
    /// processed and whose orderings are given in `children_states`.
    pub fn new_with_children(
        children_states: Vec<SortProperties>,
        parent_expr: Arc<dyn PhysicalExpr>,
    ) -> Self {
        Self {
            children_states: Some(children_states),
            ..Self::new(parent_expr)
        }
    }
}
impl TreeNode for ExprOrdering {
    /// Visits each child node in order with `op`.
    ///
    /// A `Stop` from `op` halts the walk and is propagated; a `Skip` ends the loop
    /// early but is reported to the caller as `Continue`.
    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
    where
        F: FnMut(&Self) -> Result<VisitRecursion>,
    {
        for child in self.children() {
            let outcome = op(&child)?;
            match outcome {
                VisitRecursion::Continue => continue,
                VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
            }
        }
        Ok(VisitRecursion::Continue)
    }

    /// Transforms every child with `transform` and rebuilds this node from the
    /// children's resulting states.
    ///
    /// Leaf nodes are returned untouched. A child whose `state` is still `None`
    /// after transformation contributes `SortProperties::Unordered`. The rebuilt
    /// node has `state == None`; the first error from `transform` is returned.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        let children = self.children();
        if children.is_empty() {
            return Ok(self);
        }

        let children_states = children
            .into_iter()
            .map(transform)
            .map_ok(|child| child.state.unwrap_or(SortProperties::Unordered))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self::new_with_children(children_states, self.expr))
    }
}
/// Computes `node.state`, the ordering of `node.expr` relative to `sort_expr`.
///
/// The node is resolved in this order:
/// 1. If `node.expr` equals `sort_expr.expr`, it is `Ordered` with
///    `sort_expr.options`.
/// 2. If the node carries `children_states` (filled in by `map_children`), the
///    expression derives its ordering from them via `get_ordering`.
/// 3. If the node is a `Column`, it is matched against `sort_expr` using the
///    equivalence and ordering-equivalence properties. On a match the state is
///    `Ordered` with the matched options; otherwise it is `None`.
/// 4. Any other leaf asks `get_ordering` with no child states.
///
/// Because step 2 reads children states produced by `map_children`, this appears
/// intended as a bottom-up (`transform_up`) callback — confirm at call sites.
/// In the code shown it always returns `Ok(Transformed::Yes(..))`.
pub fn update_ordering(
    mut node: ExprOrdering,
    sort_expr: &PhysicalSortExpr,
    equal_properties: &EquivalenceProperties,
    ordering_equal_properties: &OrderingEquivalenceProperties,
) -> Result<Transformed<ExprOrdering>> {
    // Direct hit: this node is the sort expression itself.
    if sort_expr.expr.eq(&node.expr) {
        node.state = Some(SortProperties::Ordered(sort_expr.options));
        return Ok(Transformed::Yes(node));
    }

    let computed = if let Some(child_states) = &node.children_states {
        // Intermediate node: combine the already-computed child orderings.
        Some(node.expr.get_ordering(child_states))
    } else if let Some(column) = node.expr.as_any().downcast_ref::<Column>() {
        // Column leaf: look for a (possibly equivalent) match with `sort_expr`.
        get_indices_of_matching_sort_exprs_with_order_eq(
            &[sort_expr.clone()],
            &[column.clone()],
            equal_properties,
            ordering_equal_properties,
        )
        .map(|(sort_options, _)| {
            let matched = &sort_options[0];
            SortProperties::Ordered(SortOptions {
                descending: matched.descending,
                nulls_first: matched.nulls_first,
            })
        })
    } else {
        // Any other leaf (e.g. a literal): let the expression decide on its own.
        Some(node.expr.get_ordering(&[]))
    };

    node.state = computed;
    Ok(Transformed::Yes(node))
}