use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::functions::all_default_functions;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use futures::future::AbortHandle;
use kapot_core::error::KapotError;
use kapot_core::serde::protobuf::{self, ExecutorRegistration};
use kapot_core::serde::scheduler::PartitionId;

use crate::execution_engine::{DefaultExecutionEngine, ExecutionEngine, QueryStageExecutor};
use crate::metrics::ExecutorMetricsCollector;
/// Future that resolves once the executor has no in-flight tasks left.
///
/// Note that `poll` never registers a waker, so this future only makes
/// progress when the surrounding task is woken by something else (for
/// example a timer it is raced against).
pub struct TasksDrainedFuture(pub Arc<Executor>);

impl Future for TasksDrainedFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.0.abort_handles.is_empty() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
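// Drain sketch (illustrative, not from the original source): because `poll`
// above never registers a waker, a bare `.await` on `TasksDrainedFuture` may
// stall even after the last task finishes. One alternative, assuming a tokio
// runtime, is to poll `active_task_count` on a timer until the executor is
// idle or a deadline passes:
//
//     async fn drain_with_deadline(executor: Arc<Executor>, deadline: Duration) {
//         let start = std::time::Instant::now();
//         while executor.active_task_count() > 0 && start.elapsed() < deadline {
//             tokio::time::sleep(Duration::from_millis(100)).await;
//         }
//     }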
/// Shared map of abort handles for in-flight tasks, keyed by `(task_id, partition)`.
type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;
/// KAPOT executor, responsible for executing query stages assigned by the
/// scheduler and for tracking abort handles so in-flight tasks can be cancelled.
#[derive(Clone)]
pub struct Executor {
    /// Metadata used to register this executor with a scheduler.
    pub metadata: ExecutorRegistration,
    /// Directory for temporary files, such as shuffle output.
    pub work_dir: String,
    /// Scalar UDFs available to this executor.
    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs available to this executor.
    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Window UDFs available to this executor (empty by default; see `new`).
    pub window_functions: HashMap<String, Arc<WindowUDF>>,
    /// DataFusion runtime environment.
    runtime: Arc<RuntimeEnv>,
    /// Optional runtime environment backed by a data cache; see `get_runtime`.
    runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
    /// Collects metrics for completed stages.
    pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,
    /// Number of tasks this executor can run concurrently.
    pub concurrent_tasks: usize,
    /// Abort handles for in-flight tasks, keyed by `(task_id, partition)`.
    abort_handles: AbortHandles,
    /// Engine used to execute query stages.
    pub(crate) execution_engine: Arc<dyn ExecutionEngine>,
}
impl Executor {
    /// Creates a new executor with the default scalar and aggregate UDFs
    /// registered. Window UDFs are left empty; callers can insert their own
    /// into `window_functions`. Falls back to `DefaultExecutionEngine` when
    /// no execution engine is provided.
    pub fn new(
        metadata: ExecutorRegistration,
        work_dir: &str,
        runtime: Arc<RuntimeEnv>,
        runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
        metrics_collector: Arc<dyn ExecutorMetricsCollector>,
        concurrent_tasks: usize,
        execution_engine: Option<Arc<dyn ExecutionEngine>>,
    ) -> Self {
        let scalar_functions = all_default_functions()
            .into_iter()
            .map(|f| (f.name().to_string(), f))
            .collect();
        let aggregate_functions = all_default_aggregate_functions()
            .into_iter()
            .map(|f| (f.name().to_string(), f))
            .collect();
        Self {
            metadata,
            work_dir: work_dir.to_owned(),
            scalar_functions,
            aggregate_functions,
            window_functions: HashMap::new(),
            runtime,
            runtime_with_data_cache,
            metrics_collector,
            concurrent_tasks,
            abort_handles: Default::default(),
            execution_engine: execution_engine
                .unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
        }
    }
}
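// Usage sketch (illustrative; `my_udf` stands for a caller-provided
// ScalarUDF): the UDF maps are public, so additional functions can be
// registered after construction. `window_functions` in particular starts
// empty.
//
//     let mut executor = Executor::new(/* ... */);
//     executor
//         .scalar_functions
//         .insert(my_udf.name().to_string(), Arc::new(my_udf));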
impl Executor {
    /// Returns the runtime environment to use for a task. When `data_cache`
    /// is true and a cache-backed runtime is configured, that runtime is
    /// returned; otherwise the default runtime is used.
    pub fn get_runtime(&self, data_cache: bool) -> Arc<RuntimeEnv> {
        if data_cache {
            if let Some(runtime) = self.runtime_with_data_cache.clone() {
                return runtime;
            }
        }
        self.runtime.clone()
    }
    /// Executes one query stage for the given partition and returns the
    /// shuffle partitions it wrote. The task is registered in `abort_handles`
    /// for the duration of the call so that it can be cancelled via
    /// `cancel_task` with the same `task_id` and partition.
    pub async fn execute_query_stage(
        &self,
        task_id: usize,
        partition: PartitionId,
        query_stage_exec: Arc<dyn QueryStageExecutor>,
        task_ctx: Arc<TaskContext>,
    ) -> Result<Vec<protobuf::ShuffleWritePartition>, KapotError> {
        let (task, abort_handle) = futures::future::abortable(
            query_stage_exec.execute_query_stage(partition.partition_id, task_ctx),
        );
        self.abort_handles
            .insert((task_id, partition.clone()), abort_handle);
        // Remove the handle before propagating any error so that a failed
        // task does not leave a stale entry behind.
        let task_result = task.await;
        self.abort_handles.remove(&(task_id, partition.clone()));
        let partitions = task_result??;
        self.metrics_collector.record_stage(
            &partition.job_id,
            partition.stage_id,
            partition.partition_id,
            query_stage_exec,
        );
        Ok(partitions)
    }
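    // Cancellation sketch (illustrative, not from the original source): a
    // cancel only takes effect if it uses the same `task_id` and partition
    // triple that was passed to `execute_query_stage`, since those form the
    // key in `abort_handles`:
    //
    //     let cancelled = executor
    //         .cancel_task(task_id, "job-id".to_owned(), stage_id, partition_id)
    //         .await?;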
    /// Attempts to cancel a running task. Returns `Ok(true)` if a matching
    /// task was found and aborted, and `Ok(false)` if no such task was
    /// running (for example because it had already completed).
    pub async fn cancel_task(
        &self,
        task_id: usize,
        job_id: String,
        stage_id: usize,
        partition_id: usize,
    ) -> Result<bool, KapotError> {
        if let Some((_, handle)) = self.abort_handles.remove(&(
            task_id,
            PartitionId {
                job_id,
                stage_id,
                partition_id,
            },
        )) {
            handle.abort();
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Returns the executor's working directory.
    pub fn work_dir(&self) -> &str {
        &self.work_dir
    }

    /// Returns the number of tasks currently running on this executor.
    pub fn active_task_count(&self) -> usize {
        self.abort_handles.len()
    }
}
#[cfg(test)]
mod test {
    use std::any::Any;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use arrow::datatypes::{Schema, SchemaRef};
    use arrow::record_batch::RecordBatch;
    use datafusion::error::{DataFusionError, Result};
    use datafusion::execution::context::TaskContext;
    use datafusion::physical_expr::EquivalenceProperties;
    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
    use datafusion::physical_plan::{
        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
        RecordBatchStream, SendableRecordBatchStream, Statistics,
    };
    use datafusion::prelude::SessionContext;
    use futures::Stream;
    use kapot_core::execution_plans::ShuffleWriterExec;
    use kapot_core::serde::protobuf::ExecutorRegistration;
    use kapot_core::serde::scheduler::PartitionId;
    use tempfile::TempDir;

    use crate::execution_engine::DefaultQueryStageExec;
    use crate::executor::Executor;
    use crate::metrics::LoggingMetricsCollector;
    /// A stream that never yields a batch and never completes, used to
    /// simulate a long-running task that must be cancelled.
    struct NeverendingRecordBatchStream;

    impl RecordBatchStream for NeverendingRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            Arc::new(Schema::empty())
        }
    }

    impl Stream for NeverendingRecordBatchStream {
        type Item = Result<RecordBatch, DataFusionError>;

        // Always pending: the stream intentionally never produces data and
        // never registers a waker, since the test aborts the task externally.
        fn poll_next(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }
    /// Single-partition operator that returns `NeverendingRecordBatchStream`.
    /// It is declared bounded for simplicity even though its stream never
    /// terminates; the test relies on external cancellation instead.
    #[derive(Debug)]
    pub struct NeverendingOperator {
        properties: PlanProperties,
    }

    impl NeverendingOperator {
        fn new() -> Self {
            let equivalence_properties =
                EquivalenceProperties::new(Arc::new(Schema::empty()));
            NeverendingOperator {
                properties: PlanProperties::new(
                    equivalence_properties,
                    Partitioning::UnknownPartitioning(1),
                    EmissionType::Both,
                    Boundedness::Bounded,
                ),
            }
        }
    }
impl DisplayAs for NeverendingOperator {
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "NeverendingOperator")
}
}
}
}
impl ExecutionPlan for NeverendingOperator {
fn name(&self) -> &str {
"NeverendingOperator"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::new(Schema::empty())
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> datafusion::common::Result<SendableRecordBatchStream> {
Ok(Box::pin(NeverendingRecordBatchStream))
}
fn statistics(&self) -> Result<Statistics> {
Ok(Statistics::new_unknown(&self.schema()))
}
}
#[tokio::test]
async fn test_task_cancellation() {
let work_dir = TempDir::new()
.unwrap()
.into_path()
.into_os_string()
.into_string()
.unwrap();
let shuffle_write = ShuffleWriterExec::try_new(
"job-id".to_owned(),
1,
Arc::new(NeverendingOperator::new()),
work_dir.clone(),
None,
)
.expect("creating shuffle writer");
let query_stage_exec = DefaultQueryStageExec::new(shuffle_write);
let executor_registration = ExecutorRegistration {
id: "executor".to_string(),
port: 0,
grpc_port: 0,
specification: None,
optional_host: None,
};
let ctx = SessionContext::new();
let executor = Executor::new(
executor_registration,
&work_dir,
ctx.runtime_env(),
None,
Arc::new(LoggingMetricsCollector {}),
2,
None,
);
        let (sender, receiver) = tokio::sync::oneshot::channel();

        // Run the stage on a separate task; it would block forever on the
        // neverending stream unless the cancellation below aborts it.
        let executor_clone = executor.clone();
        tokio::task::spawn(async move {
            let part = PartitionId {
                job_id: "job-id".to_owned(),
                stage_id: 1,
                partition_id: 0,
            };
            let task_result = executor_clone
                .execute_query_stage(1, part, Arc::new(query_stage_exec), ctx.task_ctx())
                .await;
            sender.send(task_result).expect("sending result");
        });
        // The abort handle is registered asynchronously by the spawned task,
        // so retry the cancellation until the handle is present.
        for _ in 0..20 {
            if executor
                .cancel_task(1, "job-id".to_owned(), 1, 0)
                .await
                .expect("cancelling task")
            {
                break;
            } else {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
        }
        // The cancelled task should report an error (the abort) well before
        // the timeout, instead of hanging on the neverending stream.
        let result = tokio::time::timeout(Duration::from_secs(5), receiver).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap().unwrap();
        assert!(inner_result.is_err());
}
}