use std::cmp::Ordering;
use std::sync::Arc;
use arrow::array::{
Array, ArrowPrimitiveType, GenericByteArray, GenericByteViewArray, OffsetSizeTrait,
PrimitiveArray, StringViewArray, types::ByteArrayType,
};
use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
use arrow::compute::SortOptions;
use arrow::datatypes::ArrowNativeTypeOp;
use arrow::row::Rows;
use datafusion_execution::memory_pool::MemoryReservation;
/// Row storage for one batch that a [`Cursor`] walks through, compares and
/// tests for equality.
///
/// Rows are addressed by index in `0..len()`.
pub trait CursorValues {
    /// Returns the number of rows.
    fn len(&self) -> usize;

    /// Returns `true` if row `l_idx` of `l` equals row `r_idx` of `r`.
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool;

    /// Returns `true` if row `idx` of `cursor` equals row `idx - 1` of the same values.
    ///
    /// Callers are expected to pass `idx > 0`; `Cursor::is_eq_to_prev_one` only
    /// calls this when its offset is non-zero.
    fn eq_to_previous(cursor: &Self, idx: usize) -> bool;

    /// Returns the ordering of row `l_idx` of `l` relative to row `r_idx` of `r`.
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering;
}
/// A position within a batch of [`CursorValues`].
///
/// Cursors compare and test equality by the row at their current `offset`.
#[derive(Debug)]
pub struct Cursor<T: CursorValues> {
    /// Index of the current row; equals `values.len()` once the cursor is finished.
    offset: usize,
    /// The rows this cursor iterates over.
    values: T,
}
impl<T: CursorValues> Cursor<T> {
pub fn new(values: T) -> Self {
Self { offset: 0, values }
}
pub fn is_finished(&self) -> bool {
self.offset == self.values.len()
}
pub fn advance(&mut self) -> usize {
let t = self.offset;
self.offset += 1;
t
}
pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor<T>>) -> bool {
if self.offset > 0 {
self.is_eq_to_prev_row()
} else if let Some(prev_cursor) = prev_cursor {
self.is_eq_to_prev_row_in_prev_batch(prev_cursor)
} else {
false
}
}
}
impl<T: CursorValues> PartialEq for Cursor<T> {
    /// Two cursors are equal when the rows they currently point at are equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self, other);
        T::eq(&lhs.values, lhs.offset, &rhs.values, rhs.offset)
    }
}
impl<T: CursorValues> Cursor<T> {
    /// Compares the current row with the row immediately before it in the same batch.
    fn is_eq_to_prev_row(&self) -> bool {
        let Self { offset, values } = self;
        T::eq_to_previous(values, *offset)
    }

    /// Compares the current row (which must be the first row of this batch) with
    /// the last row of `other`.
    ///
    /// # Panics
    ///
    /// Panics if this cursor is not at offset 0.
    fn is_eq_to_prev_row_in_prev_batch(&self, other: &Self) -> bool {
        assert_eq!(self.offset, 0);
        let last_of_other = other.values.len() - 1;
        // `self.offset` is known to be 0 here.
        T::eq(&self.values, 0, &other.values, last_of_other)
    }
}
impl<T: CursorValues> Eq for Cursor<T> {}

impl<T: CursorValues> PartialOrd for Cursor<T> {
    /// Cursor ordering is total, so this always defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T: CursorValues> Ord for Cursor<T> {
    /// Orders cursors by the rows they currently point at.
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.values, &other.values);
        T::compare(lhs, self.offset, rhs, other.offset)
    }
}
/// [`CursorValues`] backed by an arrow [`Rows`] buffer.
#[derive(Debug)]
pub struct RowValues {
    /// The rows; shared via `Arc`.
    rows: Arc<Rows>,
    /// Memory reservation accounting for `rows`. It is never read; holding it
    /// keeps the reservation alive as long as these values exist.
    /// `RowValues::new` checks that its size equals `rows.size()`.
    _reservation: MemoryReservation,
}
impl RowValues {
    /// Wraps `rows` together with the memory `reservation` that accounts for them.
    ///
    /// # Panics
    ///
    /// Panics if `reservation.size()` differs from `rows.size()`, or if `rows`
    /// contains no rows.
    pub fn new(rows: Arc<Rows>, reservation: MemoryReservation) -> Self {
        let rows_size = rows.size();
        let reserved = reservation.size();
        assert_eq!(rows_size, reserved, "memory reservation mismatch");
        assert!(rows.num_rows() > 0);
        Self {
            _reservation: reservation,
            rows,
        }
    }
}
impl CursorValues for RowValues {
    fn len(&self) -> usize {
        self.rows.num_rows()
    }

    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
        let lhs = l.rows.row(l_idx);
        let rhs = r.rows.row(r_idx);
        lhs == rhs
    }

    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
        assert!(idx > 0);
        let rows = &cursor.rows;
        rows.row(idx) == rows.row(idx - 1)
    }

    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
        let lhs = l.rows.row(l_idx);
        let rhs = r.rows.row(r_idx);
        lhs.cmp(&rhs)
    }
}
/// An arrow [`Array`] that can be converted into [`CursorValues`].
pub trait CursorArray: Array + 'static {
    /// The [`CursorValues`] type produced from this array.
    type Values: CursorValues;

    /// Extracts this array's values in a form a [`Cursor`] can iterate over.
    fn values(&self) -> Self::Values;
}
impl<T: ArrowPrimitiveType> CursorArray for PrimitiveArray<T> {
    type Values = PrimitiveValues<T::Native>;

    /// Wraps a clone of the array's value buffer. Nulls are not tracked here.
    fn values(&self) -> Self::Values {
        // Inherent `PrimitiveArray::values` (returns the buffer), not this trait method.
        let buffer = self.values().clone();
        PrimitiveValues(buffer)
    }
}
/// [`CursorValues`] over a buffer of primitive native values.
///
/// Every slot is compared as-is; null handling is left to a wrapper such as
/// [`ArrayValues`].
#[derive(Debug)]
pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);
impl<T: ArrowNativeTypeOp> CursorValues for PrimitiveValues<T> {
    fn len(&self) -> usize {
        let Self(buffer) = self;
        buffer.len()
    }

    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
        let (lhs, rhs) = (l.0[l_idx], r.0[r_idx]);
        lhs.is_eq(rhs)
    }

    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
        assert!(idx > 0);
        let buffer = &cursor.0;
        buffer[idx].is_eq(buffer[idx - 1])
    }

    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
        // `ArrowNativeTypeOp::compare` gives a total order for the native type.
        let (lhs, rhs) = (l.0[l_idx], r.0[r_idx]);
        lhs.compare(rhs)
    }
}
/// [`CursorValues`] over variable-length byte data (strings or binary), stored
/// as an offsets buffer plus one contiguous values buffer.
// Derives `Debug` like the other value types in this module: the derived
// `Debug` on `ArrayValues<T>`/`Cursor<T>` requires `T: Debug`, so without this
// a cursor over byte arrays could not be debug-printed.
#[derive(Debug)]
pub struct ByteArrayValues<T: OffsetSizeTrait> {
    /// `len() + 1` offsets; element `i` spans `offsets[i]..offsets[i + 1]` of `values`.
    offsets: OffsetBuffer<T>,
    /// Concatenated bytes of all elements.
    values: Buffer,
}
impl<T: OffsetSizeTrait> ByteArrayValues<T> {
    /// Returns the bytes of element `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    fn value(&self, idx: usize) -> &[u8] {
        assert!(idx < self.len());
        // SAFETY: `len()` is `offsets.len() - 1`, so the assert guarantees that both
        // `idx` and `idx + 1` are valid positions in `offsets`.
        let (start, end) = unsafe {
            (
                self.offsets.get_unchecked(idx).as_usize(),
                self.offsets.get_unchecked(idx + 1).as_usize(),
            )
        };
        // SAFETY: the offsets come from a `GenericByteArray` (see its `CursorArray`
        // impl), whose offsets are expected to index within its values buffer.
        unsafe { self.values.get_unchecked(start..end) }
    }
}
impl<T: OffsetSizeTrait> CursorValues for ByteArrayValues<T> {
    fn len(&self) -> usize {
        // n elements are delimited by n + 1 offsets.
        self.offsets.len() - 1
    }

    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
        let lhs = l.value(l_idx);
        let rhs = r.value(r_idx);
        lhs == rhs
    }

    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
        assert!(idx > 0);
        let current = cursor.value(idx);
        current == cursor.value(idx - 1)
    }

    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
        // Lexicographic byte-wise ordering.
        let lhs = l.value(l_idx);
        let rhs = r.value(r_idx);
        lhs.cmp(rhs)
    }
}
impl<T: ByteArrayType> CursorArray for GenericByteArray<T> {
    type Values = ByteArrayValues<T::Offset>;

    /// Captures clones of the array's offsets and values buffers.
    fn values(&self) -> Self::Values {
        let offsets = self.offsets().clone();
        // Inherent `GenericByteArray::values` (the data buffer), not this trait method.
        let values = self.values().clone();
        ByteArrayValues { values, offsets }
    }
}
impl CursorArray for StringViewArray {
    type Values = StringViewArray;

    /// Returns the result of arrow's `gc()` on this array (a copy with compacted
    /// data buffers). NOTE(review): presumably done so inline-only arrays end up
    /// with no data buffers and hit the fast paths in `CursorValues` — confirm.
    fn values(&self) -> Self::Values {
        self.gc()
    }
}
impl CursorValues for StringViewArray {
    fn len(&self) -> usize {
        self.views().len()
    }

    /// Returns `true` if string `l_idx` of `l` equals string `r_idx` of `r`.
    ///
    /// Nulls are not considered; a wrapper such as `ArrayValues` handles them.
    ///
    /// # Panics
    ///
    /// Panics if `l_idx` or `r_idx` is out of bounds.
    #[inline(always)]
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
        // Checked indexing: this is a safe trait method, so an out-of-bounds index
        // must panic rather than reach the unchecked comparison below.
        let l_view = l.views()[l_idx];
        let r_view = r.views()[r_idx];

        // With no data buffers every string is inlined in its view, so comparing
        // whole views suffices (relies on inline views being zero-padded).
        if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
            return l_view == r_view;
        }

        // Fast reject: the low 32 bits of a view hold the string length.
        let l_len = l_view as u32;
        let r_len = r_view as u32;
        if l_len != r_len {
            return false;
        }

        // SAFETY: `l_idx` and `r_idx` were bounds-checked by the indexing above.
        unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() }
    }

    /// Returns `true` if string `idx` equals string `idx - 1` of `cursor`.
    ///
    /// # Panics
    ///
    /// Panics if `idx == 0` or `idx` is out of bounds.
    #[inline(always)]
    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
        // Same precondition as the other `eq_to_previous` impls; without it
        // `idx - 1` underflows.
        assert!(idx > 0);
        let views = cursor.views();
        let cur_view = views[idx];
        // In bounds: 0 < idx < views.len().
        let prev_view = views[idx - 1];

        if cursor.data_buffers().is_empty() {
            return cur_view == prev_view;
        }

        let cur_len = cur_view as u32;
        let prev_len = prev_view as u32;
        if cur_len != prev_len {
            return false;
        }

        // SAFETY: `idx` and `idx - 1` are both in bounds (checked above).
        unsafe {
            GenericByteViewArray::compare_unchecked(cursor, idx, cursor, idx - 1).is_eq()
        }
    }

    /// Orders string `l_idx` of `l` relative to string `r_idx` of `r`.
    ///
    /// # Panics
    ///
    /// Panics if `l_idx` or `r_idx` is out of bounds.
    #[inline(always)]
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
        // Checked indexing doubles as the bounds check `compare_unchecked` requires.
        let l_view = l.views()[l_idx];
        let r_view = r.views()[r_idx];

        // All strings inline: order by arrow's inline comparison key.
        if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
            return StringViewArray::inline_key_fast(l_view)
                .cmp(&StringViewArray::inline_key_fast(r_view));
        }

        // SAFETY: `l_idx` and `r_idx` were bounds-checked by the indexing above.
        unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) }
    }
}
/// Wraps single-column [`CursorValues`] with null handling and [`SortOptions`].
///
/// Nulls are not tracked per row. They are inferred from position: the column's
/// nulls are assumed to form one contiguous run at the start (`nulls_first`) or
/// at the end, as they would in data sorted with `options`.
#[derive(Debug)]
pub struct ArrayValues<T: CursorValues> {
    /// The underlying values; their null slots are never compared.
    values: T,
    /// With `nulls_first`, the number of leading nulls; otherwise the index of
    /// the first trailing null (`len - null_count`).
    null_threshold: usize,
    /// Sort direction and null placement used by `compare`.
    options: SortOptions,
    /// Memory reservation held alongside the values; never read.
    _reservation: MemoryReservation,
}
impl<T: CursorValues> ArrayValues<T> {
    /// Builds cursor values from `array`, sorted according to `options`.
    ///
    /// # Panics
    ///
    /// Panics if `array` is empty.
    pub fn new<A: CursorArray<Values = T>>(
        options: SortOptions,
        array: &A,
        reservation: MemoryReservation,
    ) -> Self {
        assert!(array.len() > 0, "Empty array passed to FieldCursor");
        let null_count = array.null_count();
        // Boundary between the null run and the non-null rows.
        let null_threshold = if options.nulls_first {
            null_count
        } else {
            array.len() - null_count
        };

        Self {
            values: array.values(),
            null_threshold,
            options,
            _reservation: reservation,
        }
    }

    /// Returns `true` if row `idx` falls inside the null run.
    fn is_null(&self, idx: usize) -> bool {
        let below_threshold = idx < self.null_threshold;
        below_threshold == self.options.nulls_first
    }
}
impl<T: CursorValues> CursorValues for ArrayValues<T> {
    fn len(&self) -> usize {
        self.values.len()
    }

    /// Two nulls are equal; a null never equals a non-null.
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
        let (l_null, r_null) = (l.is_null(l_idx), r.is_null(r_idx));
        if l_null || r_null {
            return l_null && r_null;
        }
        T::eq(&l.values, l_idx, &r.values, r_idx)
    }

    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
        assert!(idx > 0);
        let (cur_null, prev_null) = (cursor.is_null(idx), cursor.is_null(idx - 1));
        if cur_null || prev_null {
            return cur_null && prev_null;
        }
        T::eq(&cursor.values, idx, &cursor.values, idx - 1)
    }

    /// Null placement follows `nulls_first` regardless of `descending`;
    /// `descending` only reverses the order of non-null values.
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
        // How a null compares against a non-null value.
        let null_vs_value = if l.options.nulls_first {
            Ordering::Less
        } else {
            Ordering::Greater
        };
        match (l.is_null(l_idx), r.is_null(r_idx)) {
            (true, true) => Ordering::Equal,
            (true, false) => null_vs_value,
            (false, true) => null_vs_value.reverse(),
            (false, false) if l.options.descending => {
                T::compare(&r.values, r_idx, &l.values, l_idx)
            }
            (false, false) => T::compare(&l.values, l_idx, &r.values, r_idx),
        }
    }
}
#[cfg(test)]
mod tests {
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryConsumer, MemoryPool,
    };
    use std::sync::Arc;

    use super::*;

    /// Builds a primitive cursor directly from `values`, bypassing `ArrayValues::new`.
    ///
    /// `null_count` slots are treated as null: the first `null_count` entries when
    /// `options.nulls_first`, otherwise the last `null_count`. The buffer contents
    /// in those slots are placeholders that `ArrayValues` never compares.
    fn new_primitive(
        options: SortOptions,
        values: ScalarBuffer<i32>,
        null_count: usize,
    ) -> Cursor<ArrayValues<PrimitiveValues<i32>>> {
        let null_threshold = match options.nulls_first {
            true => null_count,
            false => values.len() - null_count,
        };

        let memory_pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(10000));
        let consumer = MemoryConsumer::new("test");
        let reservation = consumer.register(&memory_pool);

        let values = ArrayValues {
            values: PrimitiveValues(values),
            null_threshold,
            options,
            _reservation: reservation,
        };

        Cursor::new(values)
    }

    /// Walks two cursors step by step under all four combinations of
    /// `descending` / `nulls_first` (despite the test's name).
    #[test]
    fn test_primitive_nulls_first() {
        // Ascending, nulls first.
        let options = SortOptions {
            descending: false,
            nulls_first: true,
        };

        // a = [NULL, 1, 2, 3]
        let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
        let mut a = new_primitive(options, buffer, 1);
        // b = [NULL, NULL, -2, -1, 1, 9]
        let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
        let mut b = new_primitive(options, buffer, 2);

        // NULL == NULL
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL == NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL < -2
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 1 > -2
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 1 > -1
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 1 == 1
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // 1 < 9
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 2 < 9
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // Ascending, nulls last.
        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // a = [0, 1, NULL, NULL]
        let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
        let mut a = new_primitive(options, buffer, 2);
        // b = [-1, NULL, NULL]
        let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
        let mut b = new_primitive(options, buffer, 2);

        // 0 > -1
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 0 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 1 < NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // NULL == NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // Descending, nulls last.
        let options = SortOptions {
            descending: true,
            nulls_first: false,
        };

        // a = [6, NULL, NULL, NULL]
        let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
        let mut a = new_primitive(options, buffer, 3);
        // b = [67, -3, NULL, NULL]
        let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
        let mut b = new_primitive(options, buffer, 2);

        // 6 sorts after 67 when descending
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 6 sorts before -3 when descending
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // NULL == NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // Descending, nulls first.
        let options = SortOptions {
            descending: true,
            nulls_first: true,
        };

        // a = [NULL, NULL, 6, 3]
        let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
        let mut a = new_primitive(options, buffer, 2);
        // b = [NULL, 4546, -3]
        let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
        let mut b = new_primitive(options, buffer, 1);

        // NULL == NULL
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL == NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL < 4546
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 sorts after 4546 when descending
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 6 sorts before -3 when descending
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);
    }
}