use std::cmp::Ordering;
use std::fmt::{self, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::vec::IntoIter;
use crate::physical_expr::{PhysicalExpr, fmt_sql};
use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use datafusion_common::{HashSet, Result};
use datafusion_expr_common::columnar_value::ColumnarValue;
use indexmap::IndexSet;
/// A physical expression paired with the [`SortOptions`] (direction and null
/// placement) used when sorting by it.
///
/// `Eq` is derived, while `PartialEq` and `Hash` are implemented by hand so
/// that both compare/hash the expression and the options.
#[derive(Clone, Debug, Eq)]
pub struct PhysicalSortExpr {
    /// The expression whose values are sorted.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Sort direction (`descending`) and null placement (`nulls_first`).
    pub options: SortOptions,
}
impl PhysicalSortExpr {
pub fn new(expr: Arc<dyn PhysicalExpr>, options: SortOptions) -> Self {
Self { expr, options }
}
pub fn new_default(expr: Arc<dyn PhysicalExpr>) -> Self {
Self::new(expr, SortOptions::default())
}
pub fn reverse(&self) -> Self {
let mut result = self.clone();
result.options = !result.options;
result
}
pub fn asc(mut self) -> Self {
self.options.descending = false;
self
}
pub fn desc(mut self) -> Self {
self.options.descending = true;
self
}
pub fn nulls_first(mut self) -> Self {
self.options.nulls_first = true;
self
}
pub fn nulls_last(mut self) -> Self {
self.options.nulls_first = false;
self
}
pub fn fmt_sql(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{} {}",
fmt_sql(self.expr.as_ref()),
to_str(&self.options)
)
}
pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result<SortColumn> {
let array_to_sort = match self.expr.evaluate(batch)? {
ColumnarValue::Array(array) => array,
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?,
};
Ok(SortColumn {
values: array_to_sort,
options: Some(self.options),
})
}
pub fn satisfy(
&self,
requirement: &PhysicalSortRequirement,
schema: &Schema,
) -> bool {
self.expr.eq(&requirement.expr)
&& requirement.options.is_none_or(|opts| {
options_compatible(
&self.options,
&opts,
self.expr.nullable(schema).unwrap_or(true),
)
})
}
pub fn satisfy_expr(&self, sort_expr: &Self, schema: &Schema) -> bool {
self.expr.eq(&sort_expr.expr)
&& options_compatible(
&self.options,
&sort_expr.options,
self.expr.nullable(schema).unwrap_or(true),
)
}
}
impl PartialEq for PhysicalSortExpr {
    /// Two sort expressions are equal when both their options and their
    /// expressions are equal. Options are compared first because that check
    /// is cheap.
    fn eq(&self, other: &Self) -> bool {
        let Self { expr, options } = self;
        *options == other.options && expr.eq(&other.expr)
    }
}
impl Hash for PhysicalSortExpr {
    /// Hashes the expression and then the options, in line with `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { expr, options } = self;
        expr.hash(state);
        options.hash(state);
    }
}
impl Display for PhysicalSortExpr {
    /// Renders as `<expr> <ordering>`, e.g. `a@0 ASC NULLS LAST`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let order = to_str(&self.options);
        write!(f, "{} {}", self.expr, order)
    }
}
/// Decides whether two sets of sort options produce compatible orderings.
///
/// For a nullable expression the options must be identical. For a
/// non-nullable expression only the direction must match, because null
/// placement has no observable effect when there are no nulls.
pub fn options_compatible(
    options_lhs: &SortOptions,
    options_rhs: &SortOptions,
    nullable: bool,
) -> bool {
    let same_direction = options_lhs.descending == options_rhs.descending;
    if !nullable {
        return same_direction;
    }
    options_lhs == options_rhs
}
/// A sort requirement on a physical expression.
///
/// Unlike [`PhysicalSortExpr`], the options are optional. `None` means any
/// direction and null placement is acceptable, as used by
/// `PhysicalSortExpr::satisfy` and `PhysicalSortRequirement::compatible`.
#[derive(Clone, Debug)]
pub struct PhysicalSortRequirement {
    /// The expression the requirement applies to.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Required sort options, or `None` when any options are acceptable.
    pub options: Option<SortOptions>,
}
impl PartialEq for PhysicalSortRequirement {
    /// Requirements are equal when their (optional) options and their
    /// expressions are equal. The cheap options check runs first.
    fn eq(&self, other: &Self) -> bool {
        let Self { expr, options } = self;
        *options == other.options && expr.eq(&other.expr)
    }
}
impl Display for PhysicalSortRequirement {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let opts_string = self.options.as_ref().map_or("NA", to_str);
write!(f, "{} {}", self.expr, opts_string)
}
}
/// Returns a value that displays `exprs` as a bracketed, comma-separated list,
/// e.g. `[a ASC, b NA]`.
///
/// Nothing is formatted until the returned value is actually displayed.
pub fn format_physical_sort_requirement_list(
    exprs: &[PhysicalSortRequirement],
) -> impl Display + '_ {
    /// Lazily renders a slice of requirements as `[r1, r2, ...]`.
    struct RequirementList<'a>(&'a [PhysicalSortRequirement]);

    impl Display for RequirementList<'_> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "[")?;
            for (position, expr) in self.0.iter().enumerate() {
                if position == 0 {
                    write!(f, "{expr}")?;
                } else {
                    write!(f, ", {expr}")?;
                }
            }
            write!(f, "]")
        }
    }

    RequirementList(exprs)
}
impl PhysicalSortRequirement {
    /// Creates a requirement on `expr`. Pass `None` for `options` when any
    /// direction and null placement is acceptable.
    pub fn new(expr: Arc<dyn PhysicalExpr>, options: Option<SortOptions>) -> Self {
        Self { expr, options }
    }

    /// Returns `true` when both requirements refer to the same expression
    /// and `other` either leaves its options unspecified or specifies exactly
    /// `self.options`.
    pub fn compatible(&self, other: &Self) -> bool {
        if !self.expr.eq(&other.expr) {
            return false;
        }
        match other.options {
            None => true,
            Some(required) => self.options == Some(required),
        }
    }
}
/// Renders `options` as an SQL-style ordering suffix.
///
/// `NULLS FIRST` is never written out, while `NULLS LAST` always is:
/// `DESC`, `DESC NULLS LAST`, `ASC`, `ASC NULLS LAST`.
#[inline]
fn to_str(options: &SortOptions) -> &str {
    if options.descending {
        if options.nulls_first {
            "DESC"
        } else {
            "DESC NULLS LAST"
        }
    } else if options.nulls_first {
        "ASC"
    } else {
        "ASC NULLS LAST"
    }
}
impl From<PhysicalSortExpr> for PhysicalSortRequirement {
fn from(value: PhysicalSortExpr) -> Self {
Self::new(value.expr, Some(value.options))
}
}
impl From<PhysicalSortRequirement> for PhysicalSortExpr {
fn from(value: PhysicalSortRequirement) -> Self {
let options = value
.options
.unwrap_or_else(|| SortOptions::new(false, false));
Self::new(value.expr, options)
}
}
/// A lexicographic ordering: a sequence of [`PhysicalSortExpr`]s whose
/// expressions are pairwise distinct.
///
/// Publicly obtained instances are non-empty. `LexOrdering::new` returns
/// `None` for empty input, and `truncate` refuses to shrink to length zero.
#[derive(Clone, Debug)]
pub struct LexOrdering {
    /// The sort expressions, in order.
    exprs: Vec<PhysicalSortExpr>,
    /// The expressions present in `exprs`. `push` uses this set to reject
    /// duplicates, and `truncate` keeps it in sync.
    set: IndexSet<Arc<dyn PhysicalExpr>>,
}
impl LexOrdering {
    /// Creates a [`LexOrdering`] from `exprs`, or returns `None` if the result
    /// would be empty.
    ///
    /// A sort expression whose `expr` already occurred earlier in the input
    /// is skipped (see [`LexOrdering::push`]), so only the first occurrence of
    /// each expression is kept.
    pub fn new(exprs: impl IntoIterator<Item = PhysicalSortExpr>) -> Option<Self> {
        // Not valid yet: a publicly returned instance must be non-empty.
        let mut candidate = Self {
            exprs: Vec::new(),
            set: IndexSet::new(),
        };
        candidate.extend(exprs);
        (!candidate.exprs.is_empty()).then_some(candidate)
    }

    /// Appends `sort_expr` unless an equal expression is already part of this
    /// ordering.
    ///
    /// Only the expression is compared. If it is already present, the
    /// incoming sort options are discarded and the existing entry is kept.
    pub fn push(&mut self, sort_expr: PhysicalSortExpr) {
        if self.set.insert(Arc::clone(&sort_expr.expr)) {
            self.exprs.push(sort_expr);
        }
    }

    /// Appends each of `sort_exprs` in order, with the same duplicate handling
    /// as [`LexOrdering::push`].
    pub fn extend(&mut self, sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>) {
        for sort_expr in sort_exprs {
            self.push(sort_expr);
        }
    }

    /// Returns the leading (most significant) sort expression.
    ///
    /// # Panics
    /// Panics only if the non-empty invariant has been broken. `new`
    /// rejects empty input and `truncate` refuses to shrink to zero.
    pub fn first(&self) -> &PhysicalSortExpr {
        self.exprs
            .first()
            .expect("LexOrdering invariant violated: ordering must be non-empty")
    }

    /// Returns the capacity of the underlying vector of sort expressions.
    pub fn capacity(&self) -> usize {
        self.exprs.capacity()
    }

    /// Shortens the ordering to its first `len` sort expressions.
    ///
    /// Returns `false` and leaves the ordering untouched when `len` is zero
    /// (which would leave it empty) or not smaller than the current length.
    /// Otherwise the dropped expressions are also removed from the lookup set
    /// and `true` is returned.
    pub fn truncate(&mut self, len: usize) -> bool {
        if len == 0 || len >= self.exprs.len() {
            return false;
        }
        // `swap_remove` reorders the set, which is fine: it is only used for
        // membership checks, while `exprs` holds the authoritative order.
        for PhysicalSortExpr { expr, .. } in self.exprs[len..].iter() {
            self.set.swap_remove(expr);
        }
        self.exprs.truncate(len);
        true
    }

    /// Returns `true` if `other` is no longer than `self` and every entry of
    /// `other` refers to the same expression as the entry at the same
    /// position in `self`, with both direction and null placement flipped
    /// (see [`is_reversed_sort_options`]).
    pub fn is_reverse(&self, other: &LexOrdering) -> bool {
        let self_exprs = self.exprs.as_slice();
        let other_exprs = other.exprs.as_slice();
        if other_exprs.len() > self_exprs.len() {
            return false;
        }
        other_exprs.iter().zip(self_exprs.iter()).all(|(req, cur)| {
            req.expr.eq(&cur.expr) && is_reversed_sort_options(&req.options, &cur.options)
        })
    }

    /// Returns the sort options of the first entry whose expression equals
    /// `expr`, or `None` if the expression is not part of this ordering.
    pub fn get_sort_options(&self, expr: &dyn PhysicalExpr) -> Option<SortOptions> {
        for e in self {
            if e.expr.as_ref().dyn_eq(expr) {
                return Some(e.options);
            }
        }
        None
    }
}
/// Returns `true` when `rhs` is the exact reverse of `lhs`, meaning both the
/// sort direction and the null placement differ.
pub fn is_reversed_sort_options(lhs: &SortOptions, rhs: &SortOptions) -> bool {
    let same_direction = lhs.descending == rhs.descending;
    let same_null_placement = lhs.nulls_first == rhs.nulls_first;
    !(same_direction || same_null_placement)
}
impl PartialEq for LexOrdering {
    /// Orderings are equal when their sort expressions are equal element-wise.
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive destructuring means a newly added field forces a revisit
        // here. `set` is skipped because it only mirrors the expressions
        // already held in `exprs`.
        let Self { exprs: lhs, set: _ } = self;
        let Self { exprs: rhs, set: _ } = other;
        lhs == rhs
    }
}
impl Eq for LexOrdering {}
impl PartialOrd for LexOrdering {
    /// Orderings are comparable only when one is a prefix of the other, in
    /// which case the shorter (coarser) one compares as less. Orderings that
    /// disagree at some shared position are incomparable (`None`).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let common_prefix_equal = self
            .exprs
            .iter()
            .zip(&other.exprs)
            .all(|(lhs, rhs)| lhs == rhs);
        if common_prefix_equal {
            Some(self.exprs.len().cmp(&other.exprs.len()))
        } else {
            None
        }
    }
}
impl<const N: usize> From<[PhysicalSortExpr; N]> for LexOrdering {
    /// Builds an ordering from a fixed-size array, with duplicate expressions
    /// removed as in `LexOrdering::new`.
    ///
    /// # Panics
    /// Panics if `N == 0`, because an empty `LexOrdering` cannot exist.
    fn from(value: [PhysicalSortExpr; N]) -> Self {
        assert!(N > 0);
        // Deduplication keeps the first element, so the result is non-empty.
        let ordering = Self::new(value);
        ordering.expect("A LexOrdering from non-empty array must be non-degenerate")
    }
}
impl Deref for LexOrdering {
    type Target = [PhysicalSortExpr];

    /// Exposes the sort expressions as a read-only slice.
    fn deref(&self) -> &Self::Target {
        &self.exprs
    }
}
impl Display for LexOrdering {
    /// Renders the sort expressions separated by `", "`, with no brackets.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        for (index, sort_expr) in self.exprs.iter().enumerate() {
            if index > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{sort_expr}")?;
        }
        Ok(())
    }
}
impl IntoIterator for LexOrdering {
    type Item = PhysicalSortExpr;
    type IntoIter = IntoIter<Self::Item>;

    /// Consumes the ordering and yields its sort expressions in order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { exprs, .. } = self;
        exprs.into_iter()
    }
}
impl<'a> IntoIterator for &'a LexOrdering {
    type Item = &'a PhysicalSortExpr;
    type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>;

    /// Borrows the sort expressions in order.
    fn into_iter(self) -> Self::IntoIter {
        self.exprs.as_slice().iter()
    }
}
impl From<LexOrdering> for Vec<PhysicalSortExpr> {
fn from(ordering: LexOrdering) -> Self {
ordering.exprs
}
}
/// A lexicographic sort requirement: a sequence of [`PhysicalSortRequirement`]s
/// whose expressions are pairwise distinct.
///
/// The public constructors (`LexRequirement::new` and the `From` impls) are
/// designed to yield non-empty instances.
#[derive(Debug, Clone, PartialEq)]
pub struct LexRequirement {
    /// The requirements, in order, deduplicated by expression.
    reqs: Vec<PhysicalSortRequirement>,
}
impl LexRequirement {
    /// Creates a [`LexRequirement`] from `reqs`, or returns `None` if the
    /// result would be empty.
    ///
    /// A requirement whose `expr` already occurred earlier in the input is
    /// dropped; only the first occurrence is kept.
    pub fn new(reqs: impl IntoIterator<Item = PhysicalSortRequirement>) -> Option<Self> {
        let (non_empty, requirements) = Self::construct(reqs);
        non_empty.then_some(requirements)
    }

    /// Returns the leading (most significant) requirement.
    ///
    /// # Panics
    /// Panics only if the non-empty invariant has been broken. `new` rejects
    /// empty input, and the `From` conversions only accept non-empty sources.
    pub fn first(&self) -> &PhysicalSortRequirement {
        self.reqs
            .first()
            .expect("LexRequirement invariant violated: requirement must be non-empty")
    }

    /// Deduplicates `reqs` by expression, keeping the first occurrence.
    ///
    /// Returns whether anything remained, together with the (possibly empty,
    /// hence private) result.
    fn construct(
        reqs: impl IntoIterator<Item = PhysicalSortRequirement>,
    ) -> (bool, Self) {
        let mut set = HashSet::new();
        let reqs = reqs
            .into_iter()
            .filter_map(|r| set.insert(Arc::clone(&r.expr)).then_some(r))
            .collect();
        (!set.is_empty(), Self { reqs })
    }
}
impl<const N: usize> From<[PhysicalSortRequirement; N]> for LexRequirement {
    /// Builds a requirement from a fixed-size array, deduplicating by
    /// expression.
    ///
    /// # Panics
    /// Panics if `N == 0`.
    fn from(value: [PhysicalSortRequirement; N]) -> Self {
        assert!(N > 0);
        // Deduplication keeps the first element, so a non-empty array cannot
        // collapse into an empty requirement.
        let (has_entries, deduplicated) = Self::construct(value);
        debug_assert!(has_entries);
        deduplicated
    }
}
impl Deref for LexRequirement {
    type Target = [PhysicalSortRequirement];

    /// Exposes the requirements as a read-only slice.
    fn deref(&self) -> &Self::Target {
        &self.reqs
    }
}
impl IntoIterator for LexRequirement {
    type Item = PhysicalSortRequirement;
    type IntoIter = IntoIter<Self::Item>;

    /// Consumes the requirement and yields its entries in order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { reqs } = self;
        reqs.into_iter()
    }
}
impl<'a> IntoIterator for &'a LexRequirement {
    type Item = &'a PhysicalSortRequirement;
    type IntoIter = std::slice::Iter<'a, PhysicalSortRequirement>;

    /// Borrows the entries in order.
    fn into_iter(self) -> Self::IntoIter {
        self.reqs.as_slice().iter()
    }
}
impl From<LexRequirement> for Vec<PhysicalSortRequirement> {
fn from(requirement: LexRequirement) -> Self {
requirement.reqs
}
}
impl From<LexOrdering> for LexRequirement {
fn from(value: LexOrdering) -> Self {
let (non_empty, requirements) =
Self::construct(value.into_iter().map(Into::into));
debug_assert!(non_empty);
requirements
}
}
impl From<LexRequirement> for LexOrdering {
fn from(value: LexRequirement) -> Self {
Self::new(value.into_iter().map(Into::into))
.expect("A LexOrdering from LexRequirement must be non-degenerate")
}
}
/// One or more alternative [`LexRequirement`]s, tagged as hard or soft.
///
/// The constructors on this type always create at least one alternative.
/// Because the variants are public, though, an empty list can still be built
/// directly; `first`/`into_single` would panic on it.
/// NOTE(review): the exact contract of `Soft` versus `Hard` (presumably,
/// whether the requirement may be ignored) is not visible here. Confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum OrderingRequirements {
    /// Alternatives marked as hard requirements.
    Hard(Vec<LexRequirement>),
    /// Alternatives marked as soft requirements.
    Soft(Vec<LexRequirement>),
}
impl OrderingRequirements {
    /// Collects `alternatives`, tagging them `Soft` when `soft` is `true` and
    /// `Hard` otherwise. Returns `None` if there are no alternatives.
    pub fn new_alternatives(
        alternatives: impl IntoIterator<Item = LexRequirement>,
        soft: bool,
    ) -> Option<Self> {
        let alternatives: Vec<LexRequirement> = alternatives.into_iter().collect();
        if alternatives.is_empty() {
            return None;
        }
        Some(if soft {
            Self::Soft(alternatives)
        } else {
            Self::Hard(alternatives)
        })
    }

    /// Creates a hard requirement with `requirement` as its single alternative.
    pub fn new(requirement: LexRequirement) -> Self {
        Self::Hard(vec![requirement])
    }

    /// Creates a soft requirement with `requirement` as its single alternative.
    pub fn new_soft(requirement: LexRequirement) -> Self {
        Self::Soft(vec![requirement])
    }

    /// Appends `requirement` as another alternative, keeping the hard/soft tag.
    pub fn add_alternative(&mut self, requirement: LexRequirement) {
        let (Self::Hard(alts) | Self::Soft(alts)) = self;
        alts.push(requirement);
    }

    /// Consumes `self` and returns its first alternative.
    ///
    /// # Panics
    /// Panics if there are no alternatives.
    pub fn into_single(self) -> LexRequirement {
        let (Self::Hard(mut alts) | Self::Soft(mut alts)) = self;
        alts.swap_remove(0)
    }

    /// Returns the first alternative.
    ///
    /// # Panics
    /// Panics if there are no alternatives.
    pub fn first(&self) -> &LexRequirement {
        let (Self::Hard(alts) | Self::Soft(alts)) = self;
        &alts[0]
    }

    /// Consumes `self`, returning the alternatives and `true` if they were
    /// tagged `Soft`.
    pub fn into_alternatives(self) -> (Vec<LexRequirement>, bool) {
        let soft = matches!(self, Self::Soft(_));
        let (Self::Hard(alts) | Self::Soft(alts)) = self;
        (alts, soft)
    }
}
impl From<LexRequirement> for OrderingRequirements {
    /// Wraps `requirement` as the single alternative of a hard requirement.
    fn from(requirement: LexRequirement) -> Self {
        Self::Hard(vec![requirement])
    }
}
impl From<LexOrdering> for OrderingRequirements {
fn from(ordering: LexOrdering) -> Self {
Self::new(ordering.into())
}
}
impl Deref for OrderingRequirements {
    type Target = [LexRequirement];

    /// Exposes the alternatives (regardless of hard/soft tag) as a slice.
    fn deref(&self) -> &Self::Target {
        let (Self::Hard(alts) | Self::Soft(alts)) = self;
        alts
    }
}
impl DerefMut for OrderingRequirements {
    /// Exposes the alternatives (regardless of hard/soft tag) as a mutable
    /// slice. Elements can be modified in place but not added or removed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let (Self::Hard(alts) | Self::Soft(alts)) = self;
        alts
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_reversed_sort_options() {
        let opts = |descending: bool, nulls_first: bool| SortOptions {
            descending,
            nulls_first,
        };
        let asc_nulls_last = opts(false, false);
        let asc_nulls_first = opts(false, true);
        let desc_nulls_last = opts(true, false);
        let desc_nulls_first = opts(true, true);

        // Flipping both direction and null placement yields a reverse.
        for (lhs, rhs) in [
            (asc_nulls_last, desc_nulls_first),
            (desc_nulls_first, asc_nulls_last),
            (asc_nulls_first, desc_nulls_last),
            (desc_nulls_last, asc_nulls_first),
        ] {
            assert!(is_reversed_sort_options(&lhs, &rhs), "{lhs:?} vs {rhs:?}");
        }

        // Identical options, or options differing in only one component,
        // are not reverses of each other.
        for (lhs, rhs) in [
            (asc_nulls_last, asc_nulls_last),
            (desc_nulls_first, desc_nulls_first),
            (asc_nulls_last, desc_nulls_last),
            (desc_nulls_last, asc_nulls_last),
            (asc_nulls_last, asc_nulls_first),
            (asc_nulls_first, asc_nulls_last),
        ] {
            assert!(!is_reversed_sort_options(&lhs, &rhs), "{lhs:?} vs {rhs:?}");
        }
    }
}