use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow_array::cast::*;
use arrow_array::*;
use crate::compute::SortOptions;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::row::dictionary::{
compute_dictionary_mapping, decode_dictionary, encode_dictionary,
};
use crate::row::fixed::{decode_bool, decode_primitive};
use crate::row::interner::OrderPreservingInterner;
use crate::row::variable::{decode_binary, decode_string};
use crate::{downcast_dictionary_array, downcast_primitive_array};
mod dictionary;
mod fixed;
mod interner;
mod variable;
/// Converts columnar [`ArrayRef`]s into a row-oriented [`Rows`] encoding whose rows can
/// be compared byte-wise (see the `Ord` impl on [`Row`]), and converts such rows back
/// into columns via [`RowConverter::convert_rows`].
#[derive(Debug)]
pub struct RowConverter {
    /// The column descriptions, in column order. The `Arc` is shared with the
    /// [`RowConfig`] of every row this converter produces, so `convert_rows` can check
    /// by pointer that rows came from this converter.
    fields: Arc<[SortField]>,
    /// One slot per field. A slot is created lazily by `convert_columns` only for
    /// dictionary-encoded columns; it stays `None` for every other column.
    interners: Vec<Option<Box<OrderPreservingInterner>>>,
}
/// Describes one column to encode: its [`DataType`] and the [`SortOptions`]
/// (sort direction and null placement) used when encoding it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortField {
    /// Sort direction / null placement for this column.
    options: SortOptions,
    /// Expected data type of the column; `convert_columns` rejects mismatching arrays.
    data_type: DataType,
}
impl SortField {
pub fn new(data_type: DataType) -> Self {
Self::new_with_options(data_type, Default::default())
}
pub fn new_with_options(data_type: DataType, options: SortOptions) -> Self {
Self { options, data_type }
}
pub fn size(&self) -> usize {
self.data_type.size() + std::mem::size_of::<Self>()
- std::mem::size_of::<DataType>()
}
}
impl RowConverter {
    /// Creates a converter for columns described by `fields`, in order.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::NotYetImplemented`] if any field has a nested data type
    /// (see [`RowConverter::supports_fields`]).
    pub fn new(fields: Vec<SortField>) -> Result<Self> {
        if !Self::supports_fields(&fields) {
            return Err(ArrowError::NotYetImplemented(format!(
                "not yet implemented: {:?}",
                fields
            )));
        }

        // Interners are created lazily (dictionary columns only) by `convert_columns`.
        let interners = (0..fields.len()).map(|_| None).collect();
        Ok(Self {
            fields: fields.into(),
            interners,
        })
    }

    /// Returns `true` if every field has a non-nested data type.
    pub fn supports_fields(fields: &[SortField]) -> bool {
        fields.iter().all(|x| !DataType::is_nested(&x.data_type))
    }

    /// Encodes `columns` (one per field, in field order) into [`Rows`].
    ///
    /// For dictionary columns, the dictionary values are interned into this
    /// converter's per-column interner, so this takes `&mut self`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] if:
    /// - the number of columns differs from the number of fields;
    /// - the columns do not all have the same length;
    /// - a column's data type does not match its field.
    pub fn convert_columns(&mut self, columns: &[ArrayRef]) -> Result<Rows> {
        if columns.len() != self.fields.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Incorrect number of arrays provided to RowConverter, expected {} got {}",
                self.fields.len(),
                columns.len()
            )));
        }

        // `new_empty_rows` takes the row count from the first column and sizes every row
        // from all columns. A column of a different length would leave offsets that no
        // longer describe the buffer, and the corrupt rows would later reach the unsafe
        // decoders, so reject it up front.
        if let Some(first) = columns.first() {
            if let Some(mismatch) = columns.iter().find(|c| c.len() != first.len()) {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "RowConverter columns must all have the same length, expected {} got {}",
                    first.len(),
                    mismatch.len()
                )));
            }
        }

        // For each column: `Some(mapping)` from dictionary key index to normalized
        // interned key for dictionary columns, `None` for every other column.
        let dictionaries = columns
            .iter()
            .zip(&mut self.interners)
            .zip(self.fields.iter())
            .map(|((column, interner), field)| {
                if !column.data_type().equals_datatype(&field.data_type) {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "RowConverter column schema mismatch, expected {} got {}",
                        field.data_type,
                        column.data_type()
                    )));
                }

                let values = downcast_dictionary_array! {
                    column => column.values(),
                    _ => return Ok(None)
                };

                let interner = interner.get_or_insert_with(Default::default);

                let mapping: Vec<_> = compute_dictionary_mapping(interner, values)
                    .into_iter()
                    .map(|maybe_interned| {
                        maybe_interned.map(|interned| interner.normalized_key(interned))
                    })
                    .collect();

                Ok(Some(mapping))
            })
            .collect::<Result<Vec<_>>>()?;

        let config = RowConfig {
            fields: Arc::clone(&self.fields),
            // Rows built here are trusted on decode; only parser rows request validation.
            validate_utf8: false,
        };
        let mut rows = new_empty_rows(columns, &dictionaries, config);

        for ((column, field), dictionary) in
            columns.iter().zip(self.fields.iter()).zip(dictionaries)
        {
            encode_column(&mut rows, column, field.options, dictionary.as_deref())
        }

        if cfg!(debug_assertions) {
            assert_eq!(*rows.offsets.last().unwrap(), rows.buffer.len());
            rows.offsets
                .windows(2)
                .for_each(|w| assert!(w[0] <= w[1], "offsets should be monotonic"));
        }

        Ok(rows)
    }

    /// Decodes `rows` back into one array per field.
    ///
    /// # Panics
    ///
    /// Panics if any row was not produced by this converter or by its [`RowParser`]
    /// (checked by pointer-comparing the shared `fields`).
    ///
    /// # Errors
    ///
    /// Propagates decoding errors, e.g. for data types that cannot be decoded.
    pub fn convert_rows<'a, I>(&self, rows: I) -> Result<Vec<ArrayRef>>
    where
        I: IntoIterator<Item = Row<'a>>,
    {
        let mut validate_utf8 = false;
        let mut rows: Vec<_> = rows
            .into_iter()
            .map(|row| {
                assert!(
                    Arc::ptr_eq(&row.config.fields, &self.fields),
                    "rows were not produced by this RowConverter"
                );
                // If any row came from a `RowParser`, validate strings for all of them.
                validate_utf8 |= row.config.validate_utf8;
                row.data
            })
            .collect();

        self.fields
            .iter()
            .zip(&self.interners)
            .map(|(field, interner)| {
                // SAFETY: every row was asserted above to share this converter's `fields`.
                // Rows created by `RowParser` set `validate_utf8`, so string columns are
                // validated rather than trusted in that case.
                unsafe {
                    decode_column(field, &mut rows, interner.as_deref(), validate_utf8)
                }
            })
            .collect()
    }

    /// Returns a [`RowParser`] that can turn raw bytes back into [`Row`]s for this converter.
    pub fn parser(&self) -> RowParser {
        RowParser::new(Arc::clone(&self.fields))
    }

    /// Approximate memory footprint of this converter in bytes, including interners.
    pub fn size(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.fields.iter().map(|x| x.size()).sum::<usize>()
            + self.interners.capacity()
                * std::mem::size_of::<Option<Box<OrderPreservingInterner>>>()
            + self
                .interners
                .iter()
                .filter_map(|x| x.as_ref().map(|x| x.size()))
                .sum::<usize>()
    }
}
/// Turns raw row bytes back into [`Row`]s for a specific [`RowConverter`].
///
/// Created by [`RowConverter::parser`]. Rows it yields request UTF-8 validation when
/// they are decoded, because their bytes are not known to come from `convert_columns`.
#[derive(Debug)]
pub struct RowParser {
    /// Shares the converter's `fields`, with `validate_utf8` set to `true`.
    config: RowConfig,
}
impl RowParser {
    /// Builds a parser sharing `fields` with its converter; parsed rows always
    /// ask for UTF-8 validation on decode.
    fn new(fields: Arc<[SortField]>) -> Self {
        let config = RowConfig {
            fields,
            validate_utf8: true,
        };
        Self { config }
    }

    /// Wraps `bytes` as a [`Row`] tagged with this parser's configuration.
    ///
    /// No validation happens here; it is deferred to `RowConverter::convert_rows`.
    pub fn parse<'a>(&'a self, bytes: &'a [u8]) -> Row<'a> {
        let config = &self.config;
        Row {
            config,
            data: bytes,
        }
    }
}
/// Configuration attached to every row, identifying the converter that produced it.
#[derive(Debug, Clone)]
struct RowConfig {
    /// The converter's fields; `convert_rows` compares this `Arc` by pointer.
    fields: Arc<[SortField]>,
    /// Whether string data must be UTF-8 validated when decoding. `false` for rows from
    /// `convert_columns`, `true` for rows created via `RowParser`.
    validate_utf8: bool,
}
/// A batch of encoded rows stored contiguously.
///
/// Row `i` occupies `buffer[offsets[i]..offsets[i + 1]]`, so `offsets` has one more
/// entry than there are rows.
#[derive(Debug)]
pub struct Rows {
    /// Concatenated row bytes.
    buffer: Box<[u8]>,
    /// Row boundaries into `buffer`.
    offsets: Box<[usize]>,
    /// Configuration shared by all rows in this batch.
    config: RowConfig,
}
impl Rows {
    /// Returns row number `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.num_rows()`.
    pub fn row(&self, row: usize) -> Row<'_> {
        let offsets = &self.offsets;
        // Tuple elements evaluate left to right: the end bound is looked up first.
        let (end, start) = (offsets[row + 1], offsets[row]);
        Row {
            data: &self.buffer[start..end],
            config: &self.config,
        }
    }

    /// Number of rows in this batch.
    pub fn num_rows(&self) -> usize {
        let boundaries = self.offsets.len();
        boundaries - 1
    }

    /// Iterates over the rows in order.
    pub fn iter(&self) -> RowsIter<'_> {
        IntoIterator::into_iter(self)
    }

    /// Approximate memory footprint of this batch in bytes.
    pub fn size(&self) -> usize {
        let offset_bytes = std::mem::size_of::<usize>() * self.offsets.len();
        std::mem::size_of::<Self>() + self.buffer.len() + offset_bytes
    }
}
impl<'a> IntoIterator for &'a Rows {
    type Item = Row<'a>;
    type IntoIter = RowsIter<'a>;

    /// Iterates over every row, from index `0` up to (excluding) `num_rows()`.
    fn into_iter(self) -> Self::IntoIter {
        let end = self.num_rows();
        RowsIter {
            rows: self,
            start: 0,
            end,
        }
    }
}
/// Iterator over the rows of a [`Rows`], created by [`Rows::iter`].
///
/// The rows still to be yielded are those with indices in `start..end`.
#[derive(Debug)]
pub struct RowsIter<'a> {
    /// The batch being iterated.
    rows: &'a Rows,
    /// Index of the next row yielded from the front (inclusive).
    start: usize,
    /// Index one past the next row yielded from the back (exclusive).
    end: usize,
}
impl<'a> Iterator for RowsIter<'a> {
    type Item = Row<'a>;

    /// Yields the row at `start` and advances the front cursor.
    fn next(&mut self) -> Option<Self::Item> {
        match self.start == self.end {
            true => None,
            false => {
                let yielded = self.rows.row(self.start);
                self.start += 1;
                Some(yielded)
            }
        }
    }

    /// Exact: the remaining count is always known.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }
}
impl<'a> ExactSizeIterator for RowsIter<'a> {
    /// Number of rows not yet yielded from either end.
    fn len(&self) -> usize {
        let (front, back) = (self.start, self.end);
        back - front
    }
}
impl<'a> DoubleEndedIterator for RowsIter<'a> {
    /// Yields the last remaining row and shrinks the back cursor.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.end == self.start {
            return None;
        }
        // `end` is exclusive: step back first so the row at `end - 1` is yielded.
        // Reading `row(self.end)` before decrementing would skip the last row and
        // index past the final offset on the first call.
        self.end -= 1;
        Some(self.rows.row(self.end))
    }
}
/// A borrowed, byte-comparable encoded row.
///
/// Equality, ordering and hashing use only `data`. `config` is carried so that
/// `RowConverter::convert_rows` can verify which converter the row belongs to.
#[derive(Debug, Copy, Clone)]
pub struct Row<'a> {
    /// The encoded bytes of this row.
    data: &'a [u8],
    /// Configuration of the producing converter / parser.
    config: &'a RowConfig,
}
impl<'a> Row<'a> {
    /// Copies this row's bytes and configuration into an [`OwnedRow`].
    pub fn owned(&self) -> OwnedRow {
        let data = self.data.to_vec();
        let config = self.config.clone();
        OwnedRow { data, config }
    }
}
// Rows test equality, compare and hash purely on their encoded bytes; `config` is ignored.

impl<'a> PartialEq for Row<'a> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.data == other.data
    }
}

impl<'a> Eq for Row<'a> {}

impl<'a> PartialOrd for Row<'a> {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Byte slices are totally ordered, so this is always `Some`.
        Some(Ord::cmp(self, other))
    }
}

impl<'a> Ord for Row<'a> {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(self.data, other.data)
    }
}

impl<'a> Hash for Row<'a> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(self.data, state)
    }
}

impl<'a> AsRef<[u8]> for Row<'a> {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        &self.data[..]
    }
}
/// An owned copy of a [`Row`], created by [`Row::owned`].
///
/// It compares, hashes and tests equality exactly like the borrowed [`Row`] it wraps.
#[derive(Debug, Clone)]
pub struct OwnedRow {
    /// The encoded bytes of this row.
    data: Vec<u8>,
    /// Configuration copied from the source row.
    config: RowConfig,
}
impl OwnedRow {
    /// Borrows this owned row as a [`Row`].
    pub fn row(&self) -> Row<'_> {
        let Self { data, config } = self;
        Row {
            data: data.as_slice(),
            config,
        }
    }
}
// OwnedRow delegates every comparison to its borrowed `Row` view.

impl PartialEq for OwnedRow {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.row(), other.row());
        lhs == rhs
    }
}

impl Eq for OwnedRow {}

impl PartialOrd for OwnedRow {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // The underlying byte order is total, so this is always `Some`.
        Some(Ord::cmp(self, other))
    }
}

impl Ord for OwnedRow {
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.row(), other.row());
        lhs.cmp(&rhs)
    }
}

impl Hash for OwnedRow {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        let borrowed = self.row();
        Hash::hash(&borrowed, state)
    }
}

impl AsRef<[u8]> for OwnedRow {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.data.as_slice()
    }
}
/// Byte used to mark a null value: `0` when `options.nulls_first` is set, `0xFF` otherwise.
///
/// NOTE(review): presumably chosen so nulls sort before/after every non-null encoding
/// under byte-wise comparison; confirm against the encoders in `fixed`/`variable`.
#[inline]
fn null_sentinel(options: SortOptions) -> u8 {
    if options.nulls_first {
        0
    } else {
        0xFF
    }
}
/// Allocates a zero-filled [`Rows`] sized to hold the encoding of every row of `cols`.
///
/// `dictionaries` holds, per column, the key mapping computed by `convert_columns`.
/// It is `Some` only for dictionary columns, and is needed here because the encoded
/// length of a dictionary value depends on its mapped key.
///
/// The returned `offsets` are deliberately shifted down by one row (see below). The
/// `encode_column` calls that follow write the buffer and are expected to advance them.
fn new_empty_rows(
    cols: &[ArrayRef],
    dictionaries: &[Option<Vec<Option<&[u8]>>>],
    config: RowConfig,
) -> Rows {
    use fixed::FixedLengthEncoding;

    // The row count comes from the first column (0 when there are no columns).
    let num_rows = cols.first().map(|x| x.len()).unwrap_or(0);

    // Encoded byte length of each row, accumulated column by column.
    let mut lengths = vec![0; num_rows];

    for (array, dict) in cols.iter().zip(dictionaries) {
        downcast_primitive_array! {
            // Primitive types: every row gets the same fixed number of bytes.
            array => lengths.iter_mut().for_each(|x| *x += fixed::encoded_len(array)),
            // Null columns contribute no bytes.
            DataType::Null => {},
            DataType::Boolean => lengths.iter_mut().for_each(|x| *x += bool::ENCODED_LEN),
            // Variable-width types: each value's length depends on its (optional) bytes.
            DataType::Binary => as_generic_binary_array::<i32>(array)
                .iter()
                .zip(lengths.iter_mut())
                .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
            DataType::LargeBinary => as_generic_binary_array::<i64>(array)
                .iter()
                .zip(lengths.iter_mut())
                .for_each(|(slice, length)| *length += variable::encoded_len(slice)),
            DataType::Utf8 => as_string_array(array)
                .iter()
                .zip(lengths.iter_mut())
                .for_each(|(slice, length)| {
                    *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
                }),
            DataType::LargeUtf8 => as_largestring_array(array)
                .iter()
                .zip(lengths.iter_mut())
                .for_each(|(slice, length)| {
                    *length += variable::encoded_len(slice.map(|x| x.as_bytes()))
                }),
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                array => {
                    // `convert_columns` always builds a mapping for dictionary columns.
                    let dict = dict.as_ref().unwrap();
                    for (v, length) in array.keys().iter().zip(lengths.iter_mut()) {
                        // A mapped key `k` costs `k.len() + 1` bytes. Null keys, and
                        // keys whose value has no mapping, cost a single byte.
                        match v.and_then(|v| dict[v as usize]) {
                            Some(k) => *length += k.len() + 1,
                            None => *length += 1,
                        }
                    }
                }
                _ => unreachable!(),
            }
            _ => unreachable!(),
        }
    }

    let mut offsets = Vec::with_capacity(num_rows + 1);
    offsets.push(0);

    // Offsets are initialised shifted down by one row index. The encoders then advance
    // them as each row is written, so the same array first tracks write positions and
    // finally holds the row boundaries.
    //
    // Example: three rows of length 3, 4 and 6 start as `0, 0, 3, 7` with a 13-byte
    // buffer. Once fully written they should read `0, 3, 7, 13`. `convert_columns`
    // asserts in debug builds that the last offset equals the buffer length.
    let mut cur_offset = 0_usize;
    for l in lengths {
        offsets.push(cur_offset);
        cur_offset = cur_offset.checked_add(l).expect("overflow");
    }

    let buffer = vec![0_u8; cur_offset];

    Rows {
        buffer: buffer.into(),
        offsets: offsets.into(),
        config,
    }
}
/// Appends the encoding of `column` to every row of `out`, dispatching on its data type.
///
/// `opts` controls sort direction / null placement. `dictionary` must be `Some` for
/// dictionary columns; it is the key mapping computed by `convert_columns`.
/// Null columns write nothing, matching the zero length `new_empty_rows` gives them.
fn encode_column(
    out: &mut Rows,
    column: &ArrayRef,
    opts: SortOptions,
    dictionary: Option<&[Option<&[u8]>]>,
) {
    downcast_primitive_array! {
        column => fixed::encode(out, column, opts),
        DataType::Null => {}
        DataType::Boolean => fixed::encode(out, as_boolean_array(column), opts),
        DataType::Binary => {
            variable::encode(out, as_generic_binary_array::<i32>(column).iter(), opts)
        }
        DataType::LargeBinary => {
            variable::encode(out, as_generic_binary_array::<i64>(column).iter(), opts)
        }
        // Strings are encoded as their UTF-8 bytes, exactly like binary values.
        DataType::Utf8 => variable::encode(
            out,
            as_string_array(column).iter().map(|x| x.map(|x| x.as_bytes())),
            opts,
        ),
        DataType::LargeUtf8 => variable::encode(
            out,
            as_largestring_array(column)
                .iter()
                .map(|x| x.map(|x| x.as_bytes())),
            opts,
        ),
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            column => encode_dictionary(out, column, dictionary.unwrap(), opts),
            _ => unreachable!()
        }
        _ => unreachable!(),
    }
}
/// Adapter used with `downcast_primitive!` in `decode_column`: decodes a primitive
/// column of native type `$t` and wraps the resulting array in an `Arc`.
macro_rules! decode_primitive_helper {
    ($t:ty, $rows:ident, $data_type:ident, $options:ident) => {
        Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
    };
}
/// Adapter used with `downcast_integer!` in `decode_column`: decodes a dictionary column
/// whose key type is `$t` and whose value type is `$v`, wrapping the result in an `Arc`.
///
/// Evaluates to an error, via `?`, when no interner exists for the column. Interners
/// start as `None` in `RowConverter::new` and are only created by `convert_columns`. So
/// decoding on a converter that never encoded this column used to hit `unwrap()` and
/// panic, e.g. with rows from `RowParser` or even an empty row iterator.
macro_rules! decode_dictionary_helper {
    ($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => {
        Arc::new(decode_dictionary::<$t>(
            $interner.ok_or_else(|| {
                ArrowError::InvalidArgumentError(
                    "cannot decode dictionary rows: this RowConverter has not interned any \
                     dictionary values for the column"
                        .to_string(),
                )
            })?,
            $v.as_ref(),
            $options,
            $rows,
        )?)
    };
}
/// Decodes one column (described by `field`) from the front of each row in `rows`.
///
/// `convert_rows` calls this once per field, passing the same `rows` each time. The
/// decoders presumably advance each slice past the bytes they consume, so the next call
/// sees the next column (NOTE(review): confirm in the `fixed`/`variable`/`dictionary`
/// modules).
///
/// `interner` is the field's interner; it is only used for dictionary fields.
/// `validate_utf8` requests UTF-8 validation of decoded string data.
///
/// # Errors
///
/// Returns [`ArrowError::NotYetImplemented`] for data types that have no decoder, and
/// propagates errors from dictionary decoding.
///
/// # Safety
///
/// The only visible caller, `convert_rows`, passes rows whose config shares this
/// converter's `fields`. NOTE(review): the exact invariant lives in the decoders in
/// sibling modules. Presumably each slice must be a well-formed encoding of `field`,
/// and string bytes must be valid UTF-8 unless `validate_utf8` is `true`. Confirm there.
unsafe fn decode_column(
    field: &SortField,
    rows: &mut [&[u8]],
    interner: Option<&OrderPreservingInterner>,
    validate_utf8: bool,
) -> Result<ArrayRef> {
    let options = field.options;
    // Cloned so the primitive decoders can take ownership of the data type.
    let data_type = field.data_type.clone();
    let array: ArrayRef = downcast_primitive! {
        data_type => (decode_primitive_helper, rows, data_type, options),
        DataType::Null => Arc::new(NullArray::new(rows.len())),
        DataType::Boolean => Arc::new(decode_bool(rows, options)),
        DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
        DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
        DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
        DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
        // Dispatch on the dictionary key type; `v` is the value type.
        DataType::Dictionary(k, v) => downcast_integer! {
            k.as_ref() => (decode_dictionary_helper, interner, v, options, rows),
            _ => unreachable!()
        },
        _ => {
            return Err(ArrowError::NotYetImplemented(format!(
                "converting {} row is not supported",
                field.data_type
            )))
        }
    };
    Ok(array)
}
// Unit tests for the row format: exact encodings, ordering properties, round trips,
// and a randomized comparison against `LexicographicalComparator`.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rand::distributions::uniform::SampleUniform;
    use rand::distributions::{Distribution, Standard};
    use rand::{thread_rng, Rng};

    use arrow_array::NullArray;

    use crate::array::{
        BinaryArray, BooleanArray, DictionaryArray, Float32Array, GenericStringArray,
        Int16Array, Int32Array, OffsetSizeTrait, PrimitiveArray,
        PrimitiveDictionaryBuilder, StringArray,
    };
    use crate::compute::{LexicographicalComparator, SortColumn};
    use crate::util::display::array_value_to_string;

    use super::*;

    // Two fixed-width columns (Int16, Float32): pins the exact offsets and bytes, checks
    // ordering between selected rows, and round-trips back to the input arrays.
    #[test]
    fn test_fixed_width() {
        let cols = [
            Arc::new(Int16Array::from_iter([
                Some(1),
                Some(2),
                None,
                Some(-5),
                Some(2),
                Some(2),
                Some(0),
            ])) as ArrayRef,
            Arc::new(Float32Array::from_iter([
                Some(1.3),
                Some(2.5),
                None,
                Some(4.),
                Some(0.1),
                Some(-4.),
                Some(-0.),
            ])) as ArrayRef,
        ];

        let mut converter = RowConverter::new(vec![
            SortField::new(DataType::Int16),
            SortField::new(DataType::Float32),
        ])
        .unwrap();
        let rows = converter.convert_columns(&cols).unwrap();

        // Every row is 8 bytes: 3 for the Int16 value and 5 for the Float32 value.
        // The all-null row (index 2) encodes as zeros.
        assert_eq!(rows.offsets.as_ref(), &[0, 8, 16, 24, 32, 40, 48, 56]);
        assert_eq!(
            rows.buffer.as_ref(),
            &[
                1, 128, 1, 1, 191, 166, 102, 102, 1, 128, 2, 1, 192, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 127, 251, 1, 192, 128, 0, 0, 1, 128, 2, 1, 189, 204, 204, 205, 1, 128, 2, 1, 63, 127, 255, 255, 1, 128, 0, 1, 127, 255, 255, 255 ]
        );

        assert!(rows.row(3) < rows.row(6));
        assert!(rows.row(0) < rows.row(1));
        assert!(rows.row(3) < rows.row(0));
        assert!(rows.row(4) < rows.row(1));
        assert!(rows.row(5) < rows.row(4));

        let back = converter.convert_rows(&rows).unwrap();
        for (expected, actual) in cols.iter().zip(&back) {
            assert_eq!(expected, actual);
        }
    }

    // Decimal128 values supplied in ascending order (null first) must encode to strictly
    // increasing rows and round-trip.
    #[test]
    fn test_decimal128() {
        let mut converter = RowConverter::new(vec![SortField::new(
            DataType::Decimal128(DECIMAL128_MAX_PRECISION, 7),
        )])
        .unwrap();
        let col = Arc::new(
            Decimal128Array::from_iter([
                None,
                Some(i128::MIN),
                Some(-13),
                Some(46_i128),
                Some(5456_i128),
                Some(i128::MAX),
            ])
            .with_precision_and_scale(38, 7)
            .unwrap(),
        ) as ArrayRef;

        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
        for i in 0..rows.num_rows() - 1 {
            assert!(rows.row(i) < rows.row(i + 1));
        }

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(col.as_ref(), back[0].as_ref())
    }

    // Same as `test_decimal128`, for Decimal256 values built from (low, high) parts.
    #[test]
    fn test_decimal256() {
        let mut converter = RowConverter::new(vec![SortField::new(
            DataType::Decimal256(DECIMAL256_MAX_PRECISION, 7),
        )])
        .unwrap();
        let col = Arc::new(
            Decimal256Array::from_iter([
                None,
                Some(i256::MIN),
                Some(i256::from_parts(0, -1)),
                Some(i256::from_parts(u128::MAX, -1)),
                Some(i256::from_parts(u128::MAX, 0)),
                Some(i256::from_parts(0, 46_i128)),
                Some(i256::from_parts(5, 46_i128)),
                Some(i256::MAX),
            ])
            .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 7)
            .unwrap(),
        ) as ArrayRef;

        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
        for i in 0..rows.num_rows() - 1 {
            assert!(rows.row(i) < rows.row(i + 1));
        }

        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(col.as_ref(), back[0].as_ref())
    }

    // Booleans with default options (null < false < true), then with descending order
    // and nulls last (the ordering fully reverses); both round-trip.
    #[test]
    fn test_bool() {
        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap();

        let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]))
            as ArrayRef;

        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
        assert!(rows.row(2) > rows.row(1));
        assert!(rows.row(2) > rows.row(0));
        assert!(rows.row(1) > rows.row(0));

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);

        let mut converter = RowConverter::new(vec![SortField::new_with_options(
            DataType::Boolean,
            SortOptions {
                descending: true,
                nulls_first: false,
            },
        )])
        .unwrap();

        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
        assert!(rows.row(2) < rows.row(1));
        assert!(rows.row(2) < rows.row(0));
        assert!(rows.row(1) < rows.row(0));
        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);
    }

    // Timezone metadata must survive a round trip, both on a plain timestamp column and
    // on the values of a dictionary column.
    #[test]
    fn test_timezone() {
        let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5])
            .with_timezone("+01:00".to_string());
        let d = a.data_type().clone();

        let mut converter =
            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
        let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(back[0].data_type(), &d);

        // Dictionary with timezone-carrying values.
        let mut a =
            PrimitiveDictionaryBuilder::<Int32Type, TimestampNanosecondType>::new();
        a.append(34).unwrap();
        a.append_null();
        a.append(345).unwrap();

        let dict = a.finish();
        let values = TimestampNanosecondArray::from(dict.values().data().clone());
        let dict_with_tz = dict.with_values(&values.with_timezone("+02:00".to_string()));
        let d = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Timestamp(
                TimeUnit::Nanosecond,
                Some("+02:00".to_string()),
            )),
        );

        assert_eq!(dict_with_tz.data_type(), &d);
        let mut converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
        let rows = converter
            .convert_columns(&[Arc::new(dict_with_tz) as _])
            .unwrap();
        let back = converter.convert_rows(&rows).unwrap();
        assert_eq!(back.len(), 1);
        assert_eq!(back[0].data_type(), &d);
    }

    // A Null column yields the right row count with zero bytes per row.
    #[test]
    fn test_null_encoding() {
        let col = Arc::new(NullArray::new(10));
        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap();
        let rows = converter.convert_columns(&[col]).unwrap();
        assert_eq!(rows.num_rows(), 10);
        assert_eq!(rows.row(1).data.len(), 0);
    }

    // Variable-width Utf8 and Binary columns. The binary values are listed in strictly
    // ascending order and include lengths around `variable::BLOCK_SIZE`; the check is
    // repeated with descending order / nulls last, where the order must fully reverse.
    #[test]
    fn test_variable_width() {
        let col = Arc::new(StringArray::from_iter([
            Some("hello"),
            Some("he"),
            None,
            Some("foo"),
            Some(""),
        ])) as ArrayRef;

        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        assert!(rows.row(1) < rows.row(0));
        assert!(rows.row(2) < rows.row(4));
        assert!(rows.row(3) < rows.row(0));
        assert!(rows.row(3) < rows.row(1));

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);

        let col = Arc::new(BinaryArray::from_iter([
            None,
            Some(vec![0_u8; 0]),
            Some(vec![0_u8; 6]),
            Some(vec![0_u8; variable::BLOCK_SIZE]),
            Some(vec![0_u8; variable::BLOCK_SIZE + 1]),
            Some(vec![1_u8; 6]),
            Some(vec![1_u8; variable::BLOCK_SIZE]),
            Some(vec![1_u8; variable::BLOCK_SIZE + 1]),
            Some(vec![0xFF_u8; 6]),
            Some(vec![0xFF_u8; variable::BLOCK_SIZE]),
            Some(vec![0xFF_u8; variable::BLOCK_SIZE + 1]),
        ])) as ArrayRef;

        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        for i in 0..rows.num_rows() {
            for j in i + 1..rows.num_rows() {
                assert!(
                    rows.row(i) < rows.row(j),
                    "{} < {} - {:?} < {:?}",
                    i,
                    j,
                    rows.row(i),
                    rows.row(j)
                );
            }
        }

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);

        let mut converter = RowConverter::new(vec![SortField::new_with_options(
            DataType::Binary,
            SortOptions {
                descending: true,
                nulls_first: false,
            },
        )])
        .unwrap();
        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

        for i in 0..rows.num_rows() {
            for j in i + 1..rows.num_rows() {
                assert!(
                    rows.row(i) > rows.row(j),
                    "{} > {} - {:?} > {:?}",
                    i,
                    j,
                    rows.row(i),
                    rows.row(j)
                );
            }
        }

        let cols = converter.convert_rows(&rows).unwrap();
        assert_eq!(&cols[0], &col);
    }

    // String dictionaries:
    // - equal values encode to equal rows;
    // - a second batch converted by the same converter stays comparable with the first,
    //   including a value not seen before ("cupcakes");
    // - descending order with nulls last reverses the ordering.
    #[test]
    fn test_string_dictionary() {
        let a = Arc::new(DictionaryArray::<Int32Type>::from_iter([
            Some("foo"),
            Some("hello"),
            Some("he"),
            None,
            Some("hello"),
            Some(""),
            Some("hello"),
            Some("hello"),
        ])) as ArrayRef;

        let mut converter =
            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
        let rows_a = converter.convert_columns(&[Arc::clone(&a)]).unwrap();

        assert!(rows_a.row(3) < rows_a.row(5));
        assert!(rows_a.row(2) < rows_a.row(1));
        assert!(rows_a.row(0) < rows_a.row(1));
        assert!(rows_a.row(3) < rows_a.row(0));

        assert_eq!(rows_a.row(1), rows_a.row(4));
        assert_eq!(rows_a.row(1), rows_a.row(6));
        assert_eq!(rows_a.row(1), rows_a.row(7));

        let cols = converter.convert_rows(&rows_a).unwrap();
        assert_eq!(&cols[0], &a);

        let b = Arc::new(DictionaryArray::<Int32Type>::from_iter([
            Some("hello"),
            None,
            Some("cupcakes"),
        ])) as ArrayRef;

        let rows_b = converter.convert_columns(&[Arc::clone(&b)]).unwrap();
        assert_eq!(rows_a.row(1), rows_b.row(0));
        assert_eq!(rows_a.row(3), rows_b.row(1));
        assert!(rows_b.row(2) < rows_a.row(0));

        let cols = converter.convert_rows(&rows_b).unwrap();
        assert_eq!(&cols[0], &b);

        let mut converter = RowConverter::new(vec![SortField::new_with_options(
            a.data_type().clone(),
            SortOptions {
                descending: true,
                nulls_first: false,
            },
        )])
        .unwrap();

        let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
        assert!(rows_c.row(3) > rows_c.row(5));
        assert!(rows_c.row(2) > rows_c.row(1));
        assert!(rows_c.row(0) > rows_c.row(1));
        assert!(rows_c.row(3) > rows_c.row(0));

        let cols = converter.convert_rows(&rows_c).unwrap();
        assert_eq!(&cols[0], &a);
    }

    // Primitive dictionaries order by value, not by dictionary key.
    #[test]
    fn test_primitive_dictionary() {
        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        builder.append(2).unwrap();
        builder.append(3).unwrap();
        builder.append(0).unwrap();
        builder.append_null();
        builder.append(5).unwrap();
        builder.append(3).unwrap();
        builder.append(-1).unwrap();

        let a = builder.finish();

        let mut converter =
            RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
        let rows = converter.convert_columns(&[Arc::new(a)]).unwrap();
        assert!(rows.row(0) < rows.row(1));
        assert!(rows.row(2) < rows.row(0));
        assert!(rows.row(3) < rows.row(2));
        assert!(rows.row(6) < rows.row(2));
        assert!(rows.row(3) < rows.row(6));
    }

    // A null dictionary value and a null key must encode identically.
    #[test]
    fn test_dictionary_nulls() {
        let values =
            Int32Array::from_iter([Some(1), Some(-1), None, Some(4), None]).into_data();
        let keys =
            Int32Array::from_iter([Some(0), Some(0), Some(1), Some(2), Some(4), None])
                .into_data();

        let data_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
        let data = keys
            .into_builder()
            .data_type(data_type.clone())
            .child_data(vec![values])
            .build()
            .unwrap();

        let mut converter = RowConverter::new(vec![SortField::new(data_type)]).unwrap();
        let rows = converter
            .convert_columns(&[Arc::new(DictionaryArray::<Int32Type>::from(data))])
            .unwrap();

        assert_eq!(rows.row(0), rows.row(1));
        assert_eq!(rows.row(3), rows.row(4));
        assert_eq!(rows.row(4), rows.row(5));
        assert!(rows.row(3) < rows.row(0));
    }

    // Bytes encoded as Binary and re-parsed as a Utf8 row via `RowParser` must be
    // UTF-8 validated on decode (parser rows set `validate_utf8`).
    #[test]
    #[should_panic(expected = "Invalid UTF-8 sequence")]
    fn test_invalid_utf8() {
        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
        let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _;
        let rows = converter.convert_columns(&[array]).unwrap();
        let binary_row = rows.row(0);

        let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
        let parser = converter.parser();
        let utf8_row = parser.parse(binary_row.as_ref());

        converter.convert_rows(std::iter::once(utf8_row)).unwrap();
    }

    // Rows from one converter are rejected by another, even with identical fields,
    // because `fields` is compared by pointer.
    #[test]
    #[should_panic(expected = "rows were not produced by this RowConverter")]
    fn test_different_converter() {
        let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)]));
        let mut converter =
            RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
        let rows = converter.convert_columns(&[values]).unwrap();

        let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
        let _ = converter.convert_rows(&rows);
    }

    // Random primitive array of `len` values, each valid with probability `valid_percent`.
    fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K>
    where
        K: ArrowPrimitiveType,
        Standard: Distribution<K::Native>,
    {
        let mut rng = thread_rng();
        (0..len)
            .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen()))
            .collect()
    }

    // Random strings of up to 99 ASCII characters, each valid with probability `valid_percent`.
    fn generate_strings<O: OffsetSizeTrait>(
        len: usize,
        valid_percent: f64,
    ) -> GenericStringArray<O> {
        let mut rng = thread_rng();
        (0..len)
            .map(|_| {
                rng.gen_bool(valid_percent).then(|| {
                    let len = rng.gen_range(0..100);
                    let bytes = (0..len).map(|_| rng.gen_range(0..128)).collect();
                    String::from_utf8(bytes).unwrap()
                })
            })
            .collect()
    }

    // Random dictionary over `values`, with keys drawn uniformly from
    // `0..values.len()` and each key valid with probability `valid_percent`.
    fn generate_dictionary<K>(
        values: ArrayRef,
        len: usize,
        valid_percent: f64,
    ) -> DictionaryArray<K>
    where
        K: ArrowDictionaryKeyType,
        K::Native: SampleUniform,
    {
        let mut rng = thread_rng();
        let min_key = K::Native::from_usize(0).unwrap();
        let max_key = K::Native::from_usize(values.len()).unwrap();
        let keys: PrimitiveArray<K> = (0..len)
            .map(|_| {
                rng.gen_bool(valid_percent)
                    .then(|| rng.gen_range(min_key..max_key))
            })
            .collect();

        let data_type = DataType::Dictionary(
            Box::new(K::DATA_TYPE),
            Box::new(values.data_type().clone()),
        );

        let data = keys
            .into_data()
            .into_builder()
            .data_type(data_type)
            .add_child_data(values.data().clone())
            .build()
            .unwrap();

        DictionaryArray::from(data)
    }

    // Random column of length `len`, with one of nine supported types chosen uniformly.
    // Dictionary value arrays need `len >= 2` (`gen_range(1..len)`); `fuzz_test` uses `len >= 5`.
    fn generate_column(len: usize) -> ArrayRef {
        let mut rng = thread_rng();
        match rng.gen_range(0..9) {
            0 => Arc::new(generate_primitive_array::<Int32Type>(len, 0.8)),
            1 => Arc::new(generate_primitive_array::<UInt32Type>(len, 0.8)),
            2 => Arc::new(generate_primitive_array::<Int64Type>(len, 0.8)),
            3 => Arc::new(generate_primitive_array::<UInt64Type>(len, 0.8)),
            4 => Arc::new(generate_primitive_array::<Float32Type>(len, 0.8)),
            5 => Arc::new(generate_primitive_array::<Float64Type>(len, 0.8)),
            6 => Arc::new(generate_strings::<i32>(len, 0.8)),
            7 => Arc::new(generate_dictionary::<Int64Type>(
                // Cannot test dictionaries containing null values because of #2687
                Arc::new(generate_strings::<i32>(rng.gen_range(1..len), 1.0)),
                len,
                0.8,
            )),
            8 => Arc::new(generate_dictionary::<Int64Type>(
                // Cannot test dictionaries containing null values because of #2687
                Arc::new(generate_primitive_array::<Int64Type>(
                    rng.gen_range(1..len),
                    1.0,
                )),
                len,
                0.8,
            )),
            _ => unreachable!(),
        }
    }

    // Comma-separated display of row `row` across all columns, for failure messages.
    fn print_row(cols: &[SortColumn], row: usize) -> String {
        let t: Vec<_> = cols
            .iter()
            .map(|x| array_value_to_string(&x.values, row).unwrap())
            .collect();
        t.join(",")
    }

    // Comma-separated list of the column data types, for failure messages.
    fn print_col_types(cols: &[SortColumn]) -> String {
        let t: Vec<_> = cols
            .iter()
            .map(|x| x.values.data_type().to_string())
            .collect();
        t.join(",")
    }

    // Randomized check, run 100 times:
    // - for every pair of rows, the byte order of the encoded rows matches
    //   `LexicographicalComparator` on the same columns and options;
    // - all columns round-trip.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_test() {
        for _ in 0..100 {
            let mut rng = thread_rng();
            let num_columns = rng.gen_range(1..5);
            let len = rng.gen_range(5..100);
            let arrays: Vec<_> = (0..num_columns).map(|_| generate_column(len)).collect();

            let options: Vec<_> = (0..num_columns)
                .map(|_| SortOptions {
                    descending: rng.gen_bool(0.5),
                    nulls_first: rng.gen_bool(0.5),
                })
                .collect();

            let sort_columns: Vec<_> = options
                .iter()
                .zip(&arrays)
                .map(|(o, c)| SortColumn {
                    values: Arc::clone(c),
                    options: Some(*o),
                })
                .collect();

            let comparator = LexicographicalComparator::try_new(&sort_columns).unwrap();

            let columns = options
                .into_iter()
                .zip(&arrays)
                .map(|(o, a)| SortField::new_with_options(a.data_type().clone(), o))
                .collect();

            let mut converter = RowConverter::new(columns).unwrap();
            let rows = converter.convert_columns(&arrays).unwrap();

            for i in 0..len {
                for j in 0..len {
                    let row_i = rows.row(i);
                    let row_j = rows.row(j);
                    let row_cmp = row_i.cmp(&row_j);
                    let lex_cmp = comparator.compare(&i, &j);
                    assert_eq!(
                        row_cmp,
                        lex_cmp,
                        "({:?} vs {:?}) vs ({:?} vs {:?}) for types {}",
                        print_row(&sort_columns, i),
                        print_row(&sort_columns, j),
                        row_i,
                        row_j,
                        print_col_types(&sort_columns)
                    );
                }
            }

            let back = converter.convert_rows(&rows).unwrap();
            for (actual, expected) in back.iter().zip(&arrays) {
                assert_eq!(actual, expected)
            }
        }
    }
}