use std::sync::Arc;
use arrow::{
array::{Float64Array, Int64Array, StringArray, UInt64Array},
compute::{self, SortOptions, TakeOptions},
datatypes::Schema,
record_batch::RecordBatch,
};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion::{
execution::context::TaskContext,
physical_plan::{
memory::MemoryExec, sorts::sort_preserving_merge::SortPreservingMergeExec,
ExecutionPlan,
},
prelude::SessionContext,
};
use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
use futures::StreamExt;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::runtime::Runtime;
use lazy_static::lazy_static;
// Number of input partitions ("streams") the merge operator reads from.
const NUM_STREAMS: u64 = 8;
// Total number of rows generated per column before splitting across streams.
const INPUT_SIZE: u64 = 100000;
lazy_static! {
    // Pre-generated input data sets, built once and shared by all benchmark
    // iterations so data generation is excluded from the measured time.
    // Each is `NUM_STREAMS` sorted single-batch streams of the same schema.
    static ref I64_STREAMS: Vec<Vec<RecordBatch>> = i64_streams();
    static ref F64_STREAMS: Vec<Vec<RecordBatch>> = f64_streams();
    static ref UTF8_LOW_CARDINALITY_STREAMS: Vec<Vec<RecordBatch>> = utf8_low_cardinality_streams();
    static ref UTF8_HIGH_CARDINALITY_STREAMS: Vec<Vec<RecordBatch>> = utf8_high_cardinality_streams();
    static ref UTF8_TUPLE_STREAMS: Vec<Vec<RecordBatch>> = utf8_tuple_streams();
    static ref MIXED_TUPLE_STREAMS: Vec<Vec<RecordBatch>> = mixed_tuple_streams();
}
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("merge i64", |b| {
let case = MergeBenchCase::new(&I64_STREAMS);
b.iter(move || case.run())
});
c.bench_function("merge f64", |b| {
let case = MergeBenchCase::new(&F64_STREAMS);
b.iter(move || case.run())
});
c.bench_function("merge utf8 low cardinality", |b| {
let case = MergeBenchCase::new(&UTF8_LOW_CARDINALITY_STREAMS);
b.iter(move || case.run())
});
c.bench_function("merge utf8 high cardinality", |b| {
let case = MergeBenchCase::new(&UTF8_HIGH_CARDINALITY_STREAMS);
b.iter(move || case.run())
});
c.bench_function("merge utf8 tuple", |b| {
let case = MergeBenchCase::new(&UTF8_TUPLE_STREAMS);
b.iter(move || case.run())
});
c.bench_function("merge mixed tuple", |b| {
let case = MergeBenchCase::new(&MIXED_TUPLE_STREAMS);
b.iter(move || case.run())
});
}
/// A prepared merge benchmark: a tokio runtime, a task context, and the
/// physical plan (`SortPreservingMergeExec` over an in-memory source) to run.
struct MergeBenchCase {
    // Multi-threaded runtime that drives the async execution stream.
    runtime: Runtime,
    // Task context required by `ExecutionPlan::execute`.
    task_ctx: Arc<TaskContext>,
    // The merge plan; produces a single output partition.
    plan: Arc<dyn ExecutionPlan>,
}
impl MergeBenchCase {
    /// Builds the merge plan over `partitions` together with the runtime and
    /// task context needed to execute it.
    fn new(partitions: &[Vec<RecordBatch>]) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread().build().unwrap();

        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();

        // All partitions share a schema; sort on every column, in order.
        let schema = partitions[0][0].schema();
        let sort_exprs = make_sort_exprs(schema.as_ref());

        let source = MemoryExec::try_new(partitions, schema, None).unwrap();
        let merge = SortPreservingMergeExec::new(sort_exprs, Arc::new(source));

        Self {
            runtime,
            task_ctx,
            plan: Arc::new(merge),
        }
    }

    /// Executes the merge plan once, draining its single output partition.
    fn run(&self) {
        let plan = Arc::clone(&self.plan);
        let task_ctx = Arc::clone(&self.task_ctx);

        // A sort-preserving merge always yields exactly one partition.
        assert_eq!(plan.output_partitioning().partition_count(), 1);

        self.runtime.block_on(async move {
            let mut stream = plan.execute(0, task_ctx).unwrap();
            while let Some(batch) = stream.next().await {
                batch.expect("unexpected execution error");
            }
        })
    }
}
/// Builds one ascending (default-options) sort expression per schema column,
/// in schema order.
fn make_sort_exprs(schema: &Schema) -> Vec<PhysicalSortExpr> {
    let mut exprs = Vec::with_capacity(schema.fields().len());
    for field in schema.fields() {
        exprs.push(PhysicalSortExpr {
            expr: col(field.name(), schema).unwrap(),
            options: SortOptions::default(),
        });
    }
    exprs
}
/// Sorted streams of a single `i64` column, split across `NUM_STREAMS`.
fn i64_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    let array: Int64Array = gen.i64_values().into_iter().collect();

    let columns = vec![("i64", Arc::new(array) as _)];
    split_batch(RecordBatch::try_from_iter(columns).unwrap())
}
/// Sorted streams of a single `f64` column, split across `NUM_STREAMS`.
fn f64_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    let array: Float64Array = gen.f64_values().into_iter().collect();

    let columns = vec![("f64", Arc::new(array) as _)];
    split_batch(RecordBatch::try_from_iter(columns).unwrap())
}
/// Sorted streams of a single low-cardinality utf8 column (values drawn from
/// a small pool), split across `NUM_STREAMS`.
fn utf8_low_cardinality_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    let array: StringArray = gen.utf8_low_cardinality_values().into_iter().collect();

    let columns = vec![("utf_low", Arc::new(array) as _)];
    split_batch(RecordBatch::try_from_iter(columns).unwrap())
}
/// Sorted streams of a single high-cardinality utf8 column (random strings),
/// split across `NUM_STREAMS`.
fn utf8_high_cardinality_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    let array: StringArray = gen.utf8_high_cardinality_values().into_iter().collect();

    let columns = vec![("utf_high", Arc::new(array) as _)];
    split_batch(RecordBatch::try_from_iter(columns).unwrap())
}
/// Sorted streams of three utf8 columns — two low-cardinality, one
/// high-cardinality — sorted lexicographically as (low1, low2, high) tuples
/// so every column is non-decreasing under the multi-column sort.
fn utf8_tuple_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    // Chained zips nest as ((low1, low2), high); tuple `sort_unstable`
    // therefore compares low1 first, then low2, then high — the same order
    // the merge's sort expressions use.
    let mut tuples: Vec<_> = gen
        .utf8_low_cardinality_values()
        .into_iter()
        .zip(gen.utf8_low_cardinality_values().into_iter())
        .zip(gen.utf8_high_cardinality_values().into_iter())
        .collect();
    tuples.sort_unstable();
    // Unzip in reverse nesting order: peel off `high` first, then split the
    // remaining (low1, low2) pairs.
    let (tuples, utf8_high): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
    let (utf8_low1, utf8_low2): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
    let utf8_high: StringArray = utf8_high.into_iter().collect();
    let utf8_low1: StringArray = utf8_low1.into_iter().collect();
    let utf8_low2: StringArray = utf8_low2.into_iter().collect();
    let batch = RecordBatch::try_from_iter(vec![
        ("utf_low1", Arc::new(utf8_low1) as _),
        ("utf_low2", Arc::new(utf8_low2) as _),
        ("utf_high", Arc::new(utf8_high) as _),
    ])
    .unwrap();
    split_batch(batch)
}
/// Sorted streams of mixed-type columns (f64, utf8, utf8, i64), sorted as
/// tuples so all four columns are non-decreasing under the multi-column sort.
/// Note the leading `f64` column is the first batch of `i64_values()` cast to
/// `f64` (see the `map` below), not an independent f64 distribution.
fn mixed_tuple_streams() -> Vec<Vec<RecordBatch>> {
    let mut gen = DataGenerator::new();
    // Chained zips nest as (((i64a, low1), low2), i64b); tuple sort compares
    // i64a first, then low1, then low2, then i64b.
    let mut tuples: Vec<_> = gen
        .i64_values()
        .into_iter()
        .zip(gen.utf8_low_cardinality_values().into_iter())
        .zip(gen.utf8_low_cardinality_values().into_iter())
        .zip(gen.i64_values().into_iter())
        .collect();
    tuples.sort_unstable();
    // Unzip in reverse nesting order: i64b, then low2, then (i64a, low1).
    // `f64_values` holds i64a — it becomes the f64 column after the cast.
    let (tuples, i64_values): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
    let (tuples, utf8_low2): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
    let (f64_values, utf8_low1): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
    let f64_values: Float64Array = f64_values.into_iter().map(|v| v as f64).collect();
    let utf8_low1: StringArray = utf8_low1.into_iter().collect();
    let utf8_low2: StringArray = utf8_low2.into_iter().collect();
    let i64_values: Int64Array = i64_values.into_iter().collect();
    let batch = RecordBatch::try_from_iter(vec![
        ("f64", Arc::new(f64_values) as _),
        ("utf_low1", Arc::new(utf8_low1) as _),
        ("utf_low2", Arc::new(utf8_low2) as _),
        ("i64", Arc::new(i64_values) as _),
    ])
    .unwrap();
    split_batch(batch)
}
/// Deterministic (fixed-seed) generator of sorted benchmark data.
struct DataGenerator {
    // Seeded RNG so every benchmark run sees identical input data.
    rng: StdRng,
}
impl DataGenerator {
    /// Creates a generator with a fixed seed so runs are reproducible.
    fn new() -> Self {
        Self {
            rng: StdRng::seed_from_u64(42),
        }
    }

    /// `INPUT_SIZE` random `i64` values in `[0, INPUT_SIZE)`, sorted ascending.
    fn i64_values(&mut self) -> Vec<i64> {
        let mut values = Vec::with_capacity(INPUT_SIZE as usize);
        for _ in 0..INPUT_SIZE {
            values.push(self.rng.gen_range(0..INPUT_SIZE as i64));
        }
        values.sort_unstable();
        values
    }

    /// Same values as [`Self::i64_values`], cast to `f64`.
    fn f64_values(&mut self) -> Vec<f64> {
        self.i64_values().iter().map(|&v| v as f64).collect()
    }

    /// Sorted strings drawn from a pool of 100 distinct values
    /// ("value0" .. "value99"), shared via `Arc<str>`.
    fn utf8_low_cardinality_values(&mut self) -> Vec<Option<Arc<str>>> {
        let pool: Vec<String> = (0..100).map(|s| format!("value{}", s)).collect();

        let mut values: Vec<Option<Arc<str>>> = (0..INPUT_SIZE)
            .map(|_| {
                let idx = self.rng.gen_range(0..pool.len());
                Some(Arc::from(pool[idx].as_str()))
            })
            .collect();
        values.sort_unstable();
        values
    }

    /// Sorted random 20-character alphabetic strings (effectively distinct).
    fn utf8_high_cardinality_values(&mut self) -> Vec<Option<String>> {
        let mut values: Vec<Option<String>> = (0..INPUT_SIZE)
            .map(|_| Some(self.random_string()))
            .collect();
        values.sort_unstable();
        values
    }

    /// Produces one string of 20 random ASCII letters (alphanumeric samples
    /// with digits filtered out).
    fn random_string(&mut self) -> String {
        let rng = &mut self.rng;
        rng.sample_iter(rand::distributions::Alphanumeric)
            .filter(|c| c.is_ascii_alphabetic())
            .take(20)
            .map(char::from)
            .collect()
    }
}
fn split_batch(input_batch: RecordBatch) -> Vec<Vec<RecordBatch>> {
let mut rng = StdRng::seed_from_u64(1337);
let stream_assignments = (0..input_batch.num_rows())
.map(|_| rng.gen_range(0..NUM_STREAMS))
.collect();
(0..NUM_STREAMS)
.map(|stream| {
vec![take_columns(&input_batch, &stream_assignments, stream)]
})
.collect::<Vec<_>>()
}
/// Extracts the rows of `input_batch` whose entry in `stream_assignments`
/// equals `stream`, preserving their relative order, as a new batch with the
/// same schema.
fn take_columns(
    input_batch: &RecordBatch,
    stream_assignments: &UInt64Array,
    stream: u64,
) -> RecordBatch {
    // Collect the (ascending) row indices belonging to this stream.
    let mut rows = Vec::new();
    for (row, assignment) in stream_assignments.iter().enumerate() {
        if assignment.unwrap() == stream {
            rows.push(row as u64);
        }
    }
    let stream_indices: UInt64Array = rows.into_iter().collect();

    let options = Some(TakeOptions { check_bounds: true });
    let new_columns = input_batch
        .columns()
        .iter()
        .map(|column| compute::take(column, &stream_indices, options.clone()).unwrap())
        .collect();

    RecordBatch::try_new(input_batch.schema(), new_columns).unwrap()
}
// Wire the benchmarks into criterion's generated `main` entry point.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);