use std::sync::Arc;
use log::{debug, error};
use crate::error::Result;
use crate::execution::context::TaskContext;
use crate::physical_plan::ExecutionPlan;
use plan::{PipelinePlan, PipelinePlanner, RoutablePipeline};
use task::{spawn_plan, Task};
use rayon::{ThreadPool, ThreadPoolBuilder};
pub use task::ExecutionResults;
mod pipeline;
mod plan;
mod task;
/// Builder for a [`Scheduler`] backed by a dedicated rayon thread pool.
#[derive(Debug)]
pub struct SchedulerBuilder {
    // Underlying rayon pool configuration (thread count, thread names,
    // panic handler) accumulated before `build` is called
    inner: ThreadPoolBuilder,
}
impl SchedulerBuilder {
    /// Create a builder for a scheduler with `num_threads` workers.
    ///
    /// Workers are named `df-worker-{idx}`, and panics on workers are
    /// logged via [`format_worker_panic`] by the default panic handler.
    pub fn new(num_threads: usize) -> Self {
        let builder = ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .panic_handler(|p| error!("{}", format_worker_panic(p)))
            .thread_name(|idx| format!("df-worker-{idx}"));
        Self { inner: builder }
    }

    /// Replace the default logging panic handler (test-only), e.g. to
    /// capture worker panics on a channel.
    #[cfg(test)]
    fn panic_handler<H>(self, panic_handler: H) -> Self
    where
        H: Fn(Box<dyn std::any::Any + Send>) + Send + Sync + 'static,
    {
        Self {
            inner: self.inner.panic_handler(panic_handler),
        }
    }

    /// Build the [`Scheduler`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying thread pool cannot be created (rayon only
    /// reports a build error when the OS fails to spawn the worker
    /// threads, which is unrecoverable here).
    fn build(self) -> Scheduler {
        Scheduler {
            // `expect` (rather than a bare `unwrap`) documents the invariant
            // and gives a useful message if thread creation ever fails
            pool: Arc::new(
                self.inner
                    .build()
                    .expect("failed to build scheduler thread pool"),
            ),
        }
    }
}
/// Executes [`ExecutionPlan`]s on a dedicated rayon thread pool.
pub struct Scheduler {
    // Worker pool shared (via `Arc`) with every `Spawner` this scheduler
    // hands out
    pool: Arc<ThreadPool>,
}
impl Scheduler {
    /// Create a scheduler with `num_threads` workers using the default
    /// [`SchedulerBuilder`] configuration.
    pub fn new(num_threads: usize) -> Self {
        let builder = SchedulerBuilder::new(num_threads);
        builder.build()
    }

    /// Translate `plan` into a [`PipelinePlan`] and begin executing it,
    /// returning the [`ExecutionResults`] for the caller to consume.
    ///
    /// # Errors
    ///
    /// Returns an error if pipeline planning fails.
    pub fn schedule(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        context: Arc<TaskContext>,
    ) -> Result<ExecutionResults> {
        let planner = PipelinePlanner::new(plan, context);
        let pipeline_plan = planner.build()?;
        Ok(self.schedule_plan(pipeline_plan))
    }

    /// Begin executing an already-planned [`PipelinePlan`] on this pool.
    pub(crate) fn schedule_plan(&self, plan: PipelinePlan) -> ExecutionResults {
        spawn_plan(plan, self.spawner())
    }

    /// Produce a handle for submitting tasks onto this scheduler's pool.
    fn spawner(&self) -> Spawner {
        Spawner {
            // `Arc::clone` makes the cheap refcount bump explicit
            pool: Arc::clone(&self.pool),
        }
    }
}
/// Render a worker panic payload as a human-readable log message.
///
/// Identifies the panicking worker by its rayon thread index when called
/// from a pool thread, and extracts the panic message when the payload is
/// a `&str` or `String`; both fall back to `"UNKNOWN"` otherwise.
fn format_worker_panic(panic: Box<dyn std::any::Any + Send>) -> String {
    let worker = match rayon::current_thread_index() {
        Some(idx) => idx.to_string(),
        None => "UNKNOWN".to_string(),
    };
    // `panic!("literal")` yields a `&str` payload, `panic!("{}", x)` a
    // `String`; `panic_any` can carry anything else, which stays opaque
    let message = panic
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| panic.downcast_ref::<String>().map(String::as_str))
        .unwrap_or("UNKNOWN");
    format!("worker {worker} panicked with: {message}")
}
/// Returns `true` when called from a thread belonging to a rayon pool.
fn is_worker() -> bool {
    matches!(rayon::current_thread_index(), Some(_))
}
/// Spawn `task` onto the current thread's rayon pool.
///
/// # Panics
///
/// Panics if called from a thread outside a rayon pool.
fn spawn_local(task: Task) {
    assert!(is_worker(), "must be called from a worker");
    rayon::spawn(move || task.do_work())
}
/// Spawn `task` onto the current thread's rayon pool using the FIFO
/// variant of spawn (see rayon's `spawn_fifo` for its ordering guarantees).
///
/// # Panics
///
/// Panics if called from a thread outside a rayon pool.
fn spawn_local_fifo(task: Task) {
    assert!(is_worker(), "must be called from a worker");
    rayon::spawn_fifo(move || task.do_work())
}
/// Cloneable handle for submitting [`Task`]s onto a scheduler's thread pool.
#[derive(Debug, Clone)]
pub(crate) struct Spawner {
    // Shared rayon pool tasks are submitted to; cloning a `Spawner` only
    // bumps the refcount
    pool: Arc<ThreadPool>,
}
impl Spawner {
    /// Submit `task` to the pool, to be executed by any available worker.
    fn spawn(&self, task: Task) {
        debug!("Spawning {:?} to any worker", task);
        let work = move || task.do_work();
        self.pool.spawn(work);
    }
}
#[cfg(test)]
mod tests {
    use arrow::util::pretty::pretty_format_batches;
    use std::ops::Range;
    use std::panic::panic_any;
    use futures::{StreamExt, TryStreamExt};
    use log::info;
    use rand::distributions::uniform::SampleUniform;
    use rand::{thread_rng, Rng};
    use crate::arrow::array::{ArrayRef, PrimitiveArray};
    use crate::arrow::datatypes::{ArrowPrimitiveType, Float64Type, Int32Type};
    use crate::arrow::record_batch::RecordBatch;
    use crate::datasource::{MemTable, TableProvider};
    use crate::physical_plan::displayable;
    use crate::prelude::{SessionConfig, SessionContext};
    use super::*;

    /// Generate a primitive array of `len` random values drawn from `range`,
    /// where each slot is valid (non-null) with probability `valid_percent`.
    fn generate_primitive<T, R>(
        rng: &mut R,
        len: usize,
        valid_percent: f64,
        range: Range<T::Native>,
    ) -> ArrayRef
    where
        T: ArrowPrimitiveType,
        T::Native: SampleUniform,
        R: Rng,
    {
        Arc::new(PrimitiveArray::<T>::from_iter((0..len).map(|_| {
            // `gen_bool` decides validity; `then` yields None for null slots
            rng.gen_bool(valid_percent)
                .then(|| rng.gen_range(range.clone()))
        })))
    }

    /// Build a `RecordBatch` with random nullable columns `a` (Int32) and
    /// `b` (Float64), plus a non-null, monotonically increasing `id` column
    /// starting at `id_offset`.
    fn generate_batch<R: Rng>(
        rng: &mut R,
        row_count: usize,
        id_offset: i32,
    ) -> RecordBatch {
        let id_range = id_offset..(row_count as i32 + id_offset);
        let a = generate_primitive::<Int32Type, _>(rng, row_count, 0.5, 0..1000);
        let b = generate_primitive::<Float64Type, _>(rng, row_count, 0.5, 0. ..1000.);
        let id = PrimitiveArray::<Int32Type>::from_iter_values(id_range);
        RecordBatch::try_from_iter_with_nullable([
            ("a", a, true),
            ("b", b, true),
            ("id", Arc::new(id), false),
        ])
        .unwrap()
    }

    // Shape of the generated test data: 2 partitions x 20 batches x 100 rows
    const BATCHES_PER_PARTITION: usize = 20;
    const ROWS_PER_BATCH: usize = 100;
    const NUM_PARTITIONS: usize = 2;

    /// Generate all partitions of test data; `id` values are unique and
    /// increasing across every batch of every partition.
    fn make_batches() -> Vec<Vec<RecordBatch>> {
        let mut rng = thread_rng();
        let mut id_offset = 0;
        (0..NUM_PARTITIONS)
            .map(|_| {
                (0..BATCHES_PER_PARTITION)
                    .map(|_| {
                        let batch = generate_batch(&mut rng, ROWS_PER_BATCH, id_offset);
                        id_offset += ROWS_PER_BATCH as i32;
                        batch
                    })
                    .collect()
            })
            .collect()
    }

    /// Wrap freshly generated batches in an in-memory table provider.
    fn make_provider() -> Arc<dyn TableProvider> {
        // NOTE(review): `make_batches` runs twice — the first result is used
        // only to obtain the schema and is then discarded; the second call
        // produces the (re-randomized) data actually stored in the table
        let batches = make_batches();
        let schema = batches.first().unwrap().first().unwrap().schema();
        Arc::new(MemTable::try_new(schema, make_batches()).unwrap())
    }

    /// Initialize env_logger for tests; errors from repeated initialization
    /// across tests are deliberately ignored.
    fn init_logging() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    // End-to-end check: a variety of queries must produce the same results
    // through the scheduler as through the regular execution path.
    #[tokio::test]
    async fn test_simple() {
        init_logging();
        // Any worker panic fails the test via this handler
        let scheduler = SchedulerBuilder::new(4)
            .panic_handler(|panic| {
                unreachable!("not expect panic: {:?}", panic);
            })
            .build();
        let config = SessionConfig::new().with_target_partitions(4);
        let context = SessionContext::with_config(config);
        context.register_table("table1", make_provider()).unwrap();
        context.register_table("table2", make_provider()).unwrap();
        // Queries covering sorts, filters, joins, unions, aggregates, CTEs
        let queries = [
            "select * from table1 order by id",
            "select * from table1 where table1.a > 100 order by id",
            "select distinct a from table1 where table1.b > 100 order by a",
            "select * from table1 join table2 on table1.id = table2.id order by table1.id",
            "select id from table1 union all select id from table2 order by id",
            "select id from table1 union all select id from table2 where a > 100 order by id",
            "select id, b from (select id, b from table1 union all select id, b from table2 where a > 100 order by id) as t where b > 10 order by id, b",
            "select id, MIN(b), MAX(b), AVG(b) from table1 group by id order by id",
            "select count(*) from table1 where table1.a > 4",
            "WITH gp AS (SELECT id FROM table1 GROUP BY id)
SELECT COUNT(CAST(CAST(gp.id || 'xx' AS TIMESTAMP) AS BIGINT)) FROM gp",
        ];
        for sql in queries {
            let task = context.task_ctx();
            let query = context.sql(sql).await.unwrap();
            let plan = query.clone().create_physical_plan().await.unwrap();
            info!("Plan: {}", displayable(plan.as_ref()).indent());
            // Execute once through the scheduler...
            let stream = scheduler.schedule(plan, task).unwrap().stream();
            let scheduled: Vec<_> = stream.try_collect().await.unwrap_or_default();
            // ...and once through the standard DataFrame path as the oracle
            let expected = query.collect().await.unwrap_or_default();
            let total_expected = expected.iter().map(|x| x.num_rows()).sum::<usize>();
            let total_scheduled = scheduled.iter().map(|x| x.num_rows()).sum::<usize>();
            assert_eq!(total_expected, total_scheduled);
            info!("Query \"{}\" produced {} rows", sql, total_expected);
            // Compare the pretty-printed batches for an exact content match
            let expected = pretty_format_batches(&expected).unwrap().to_string();
            let scheduled = pretty_format_batches(&scheduled).unwrap().to_string();
            assert_eq!(
                expected, scheduled,
                "\n\nexpected:\n\n{expected}\nactual:\n\n{scheduled}\n\n"
            );
        }
    }

    // A plan scheduled twice: once consumed as a merged stream, once
    // consumed partition-by-partition via `stream_partitioned`.
    #[tokio::test]
    async fn test_partitioned() {
        init_logging();
        let scheduler = Scheduler::new(4);
        let config = SessionConfig::new().with_target_partitions(4);
        let context = SessionContext::with_config(config);
        let plan = context
            .read_table(make_provider())
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        // A plain table scan should preserve the provider's partitioning
        assert_eq!(plan.output_partitioning().partition_count(), NUM_PARTITIONS);
        let results = scheduler
            .schedule(plan.clone(), context.task_ctx())
            .unwrap();
        // Merged stream: all partitions interleaved into one stream
        let batches = results.stream().try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), NUM_PARTITIONS * BATCHES_PER_PARTITION);
        for batch in batches {
            assert_eq!(batch.num_rows(), ROWS_PER_BATCH)
        }
        // Partitioned streams: one stream per partition, collected in parallel
        let results = scheduler.schedule(plan, context.task_ctx()).unwrap();
        let streams = results.stream_partitioned();
        let partitions: Vec<Vec<_>> =
            futures::future::try_join_all(streams.into_iter().map(|s| s.try_collect()))
                .await
                .unwrap();
        assert_eq!(partitions.len(), NUM_PARTITIONS);
        for batches in partitions {
            assert_eq!(batches.len(), BATCHES_PER_PARTITION);
            for batch in batches {
                assert_eq!(batch.num_rows(), ROWS_PER_BATCH);
            }
        }
    }

    // Verifies the panic handler receives and formats all three payload
    // kinds: `&str`, `String`, and an arbitrary `panic_any` value.
    #[tokio::test]
    async fn test_panic() {
        init_logging();
        let do_test = |scheduler: Scheduler| {
            scheduler.pool.spawn(|| panic!("test"));
            scheduler.pool.spawn(|| panic!("{}", 1));
            scheduler.pool.spawn(|| panic_any(21));
        };
        // First run exercises the default (logging) panic handler
        do_test(Scheduler::new(1));
        // Second run captures formatted panics on a channel
        let (sender, receiver) = futures::channel::mpsc::unbounded();
        let scheduler = SchedulerBuilder::new(1)
            .panic_handler(move |panic| {
                let _ = sender.unbounded_send(format_worker_panic(panic));
            })
            .build();
        do_test(scheduler);
        // NOTE(review): `collect` completes when the channel closes, which
        // presumably happens once the dropped pool (owning `sender` via the
        // panic handler) finishes tearing down — confirm against rayon docs
        let mut buffer: Vec<_> = receiver.collect().await;
        // Sort for deterministic ordering: '1' < 'U' < 't' in ASCII
        buffer.sort_unstable();
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer[0], "worker 0 panicked with: 1");
        assert_eq!(buffer[1], "worker 0 panicked with: UNKNOWN");
        assert_eq!(buffer[2], "worker 0 panicked with: test");
    }
}