use crate::common::{TreeNodeExt, now_ns, on_drop_stream};
use crate::metrics::proto::df_metrics_set_to_proto;
use crate::protobuf::datafusion_error_to_tonic_status;
use crate::worker::generated::worker::{FlightAppMetadata, TaskMetrics};
use crate::worker::worker_service::{TaskDataEntries, Worker};
use crate::{DistributedConfig, DistributedTaskContext};
use arrow_flight::encode::{DictionaryHandling, FlightDataEncoder, FlightDataEncoderBuilder};
use arrow_flight::error::FlightError;
use arrow_select::dictionary::garbage_collect_any_dictionary;
use datafusion::arrow::array::{Array, AsArray, RecordBatch, RecordBatchOptions};
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::common::{Result, exec_err, internal_err};
use crate::worker::generated::worker::ExecuteTaskRequest;
use crate::worker::generated::worker::worker_service_server::WorkerService;
use crate::worker::spawn_select_all::spawn_select_all;
use crate::worker::task_data::TaskDataMetrics;
use datafusion::arrow::ipc::CompressionType;
use datafusion::arrow::ipc::writer::IpcWriteOptions;
use datafusion::common::exec_datafusion_err;
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio_stream::StreamExt;
use tonic::{Request, Response, Status};
/// Buffer size handed to `spawn_select_all` when `execute_remote_task` merges the
/// per-partition Flight streams into a single response stream.
/// NOTE(review): presumably the number of messages buffered per input stream — confirm
/// against `spawn_select_all`.
const RECORD_BATCH_BUFFER_SIZE: usize = 2;
/// How long, in seconds, `execute_local_task` waits for the coordinator to set a task's plan
/// before failing with a timeout error.
const WAIT_PLAN_TIMEOUT_SECS: u64 = 10;
/// Executes a contiguous range of partitions of a task's head plan in this process.
///
/// The task entry for `body.task_key` is fetched from `task_data_entries` (initialized with
/// `Default::default()` if absent). This call then waits up to [`WAIT_PLAN_TIMEOUT_SECS`]
/// seconds for the coordinator to set the plan. After that, each partition in
/// `target_partition_start..target_partition_end` of the head plan selected by
/// `producer_head` is executed with the task's [`TaskContext`]. Streams are returned in
/// partition order together with that context.
///
/// Every returned stream decrements the task's `num_partitions_remaining` counter when it is
/// dropped. The drop that takes the counter from 1 to 0 does three things:
/// - it invalidates the task entry (on a spawned Tokio task);
/// - it marks execution as finished;
/// - when `collect_metrics` is enabled in the [`DistributedConfig`], it sends the task
///   metrics through the task's metrics channel.
///
/// # Errors
///
/// Returns an error if any of the following happens:
/// - `task_key` or `producer_head` is missing;
/// - waiting for the plan times out, or the entry resolves to an error;
/// - the partition range is inverted (`start > end`);
/// - a requested partition is not exposed by the head plan;
/// - building the plan, reading the config, or executing a partition fails.
pub(crate) async fn execute_local_task(
    task_data_entries: &Arc<TaskDataEntries>,
    body: ExecuteTaskRequest,
) -> Result<(Vec<SendableRecordBatchStream>, Arc<TaskContext>)> {
    let Some(key) = body.task_key.as_ref().cloned() else {
        return internal_err!("Missing task_key in LocalWorkerConnection");
    };
    let Some(producer_head) = body.producer_head.as_ref().cloned() else {
        return internal_err!("Missing producer_head");
    };

    let entry = task_data_entries
        .get_with(key.clone(), async { Default::default() })
        .await;
    // Wait (bounded by the timeout) until the coordinator has set the plan for this task.
    let task_data = entry
        .read(Duration::from_secs(WAIT_PLAN_TIMEOUT_SECS))
        .await
        .map_err(|e| exec_datafusion_err!("Worker::execute_task timed-out while waiting for the plan to be set by the coordinator. ({e})"))?
        .map_err(DataFusionError::Shared)?;
    task_data.task_data_metrics.mark_execution_started_once();

    let plan = task_data.plan(producer_head)?;
    let task_ctx = task_data.task_ctx;
    let d_cfg = DistributedConfig::from_config_options(task_ctx.session_config().options())?;
    let d_ctx = *DistributedTaskContext::from_ctx(&task_ctx).as_ref();
    let send_metrics = d_cfg.collect_metrics;
    let partition_count = plan.properties().partitioning.partition_count();
    let plan_name = plan.name();

    let start = body.target_partition_start;
    let end = body.target_partition_end;
    // `end - start` on an inverted range would underflow. Debug builds panic; release builds
    // wrap to an enormous value that `Vec::with_capacity` cannot satisfy. Reject it instead.
    if start > end {
        return exec_err!(
            "invalid partition range {start}..{end} for head plan {plan_name}: start is greater than end"
        );
    }
    // Cap the pre-allocation at the plan's partition count. Otherwise a request with an
    // oversized range would trigger a huge allocation before the per-partition check below
    // rejects it.
    let capacity = (end - start).min(partition_count as u64) as usize;
    let mut streams = Vec::with_capacity(capacity);
    for partition in start..end {
        if partition >= partition_count as u64 {
            return exec_err!(
                "partition {partition} not available. The head plan {plan_name} of the stage just has {partition_count} partitions"
            );
        }
        let stream = plan.execute(partition as usize, Arc::clone(&task_ctx))?;
        let stream_schema = plan.schema();
        let plan = Arc::clone(&plan);
        let task_data_entries = Arc::clone(task_data_entries);
        let num_partitions_remaining = Arc::clone(&task_data.num_partitions_remaining);
        let metrics_tx = Arc::clone(&task_data.metrics_tx);
        let task_data_metrics = Arc::clone(&task_data.task_data_metrics);
        let key = key.clone();
        let stream = on_drop_stream(stream, move || {
            // Only the last partition stream to be dropped (counter going 1 -> 0) tears the
            // task down.
            if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
                // The drop callback is synchronous while `invalidate` is async, so the
                // invalidation is detached onto the runtime. `tokio::spawn` is disallowed by
                // this crate's clippy configuration and is opted into deliberately here.
                // Note: `tokio::spawn` panics outside a Tokio runtime.
                #[allow(clippy::disallowed_methods)]
                tokio::spawn(async move {
                    task_data_entries.invalidate(&key).await;
                });
                task_data_metrics.mark_execution_finished();
                if send_metrics {
                    send_metrics_via_channel(&metrics_tx, &plan, d_ctx, &task_data_metrics);
                }
            }
        });
        streams.push(Box::pin(RecordBatchStreamAdapter::new(stream_schema, stream)) as _);
    }
    Ok((streams, task_ctx))
}
pub(crate) async fn execute_remote_task(
task_data_entries: &Arc<TaskDataEntries>,
request: Request<ExecuteTaskRequest>,
) -> Result<Response<<Worker as WorkerService>::ExecuteTaskStream>, Status> {
let body = request.into_inner();
let partition_range = body.target_partition_start..body.target_partition_end;
let (arrow_streams, task_ctx) = execute_local_task(task_data_entries, body)
.await
.map_err(datafusion_error_to_tonic_status)?;
let d_cfg = DistributedConfig::from_config_options(task_ctx.session_config().options())
.map_err(datafusion_error_to_tonic_status)?;
let compression = match d_cfg.compression.as_str() {
"lz4" => Some(CompressionType::LZ4_FRAME),
"zstd" => Some(CompressionType::ZSTD),
"none" => None,
v => Err(Status::invalid_argument(format!(
"Unknown compression type {v}"
)))?,
};
let mut flight_streams = Vec::with_capacity(arrow_streams.len());
for (partition, arrow_stream) in partition_range.zip(arrow_streams) {
let flight_stream = build_flight_data_stream(arrow_stream, compression)?.map(move |msg| {
let flight_data = FlightAppMetadata {
partition,
created_timestamp_unix_nanos: now_ns(),
};
msg.map(|v| v.with_app_metadata(flight_data.encode_to_vec()))
});
flight_streams.push(flight_stream);
}
let memory_pool = Arc::clone(&task_ctx.runtime_env().memory_pool);
let stream = spawn_select_all(flight_streams, memory_pool, RECORD_BATCH_BUFFER_SIZE);
Ok(Response::new(Box::pin(stream.map_err(|err| match err {
FlightError::Tonic(status) => *status,
_ => Status::internal(format!("Error during flight stream: {err}")),
}))))
}
fn build_flight_data_stream(
stream: SendableRecordBatchStream,
compression_type: Option<CompressionType>,
) -> Result<FlightDataEncoder, Status> {
let stream = FlightDataEncoderBuilder::new()
.with_options(
IpcWriteOptions::default()
.try_with_compression(compression_type)
.map_err(|err| Status::internal(err.to_string()))?,
)
.with_schema(stream.schema())
.with_dictionary_handling(DictionaryHandling::Resend)
.with_max_flight_data_size(usize::MAX)
.build(
stream
.and_then(|rb| std::future::ready(garbage_collect_arrays(rb)))
.map_err(|err| FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(err)))),
);
Ok(stream)
}
/// Reports the metrics of `plan` and of the task through the task's metrics sender.
///
/// The sender is taken out of `metrics_tx`, so metrics are delivered at most once. If the
/// sender has already been taken, this returns without walking the plan.
///
/// Delivery is best-effort. The following are all ignored:
/// - traversal errors;
/// - nodes whose metrics are missing or fail to convert (they contribute a default entry);
/// - a receiver that has already been dropped.
fn send_metrics_via_channel(
    metrics_tx: &Arc<Mutex<Option<Sender<TaskMetrics>>>>,
    plan: &Arc<dyn ExecutionPlan>,
    dt_ctx: DistributedTaskContext,
    task_data_metrics: &Arc<TaskDataMetrics>,
) {
    // Claim the sender before doing any work; the guard is a temporary, so the lock is
    // released at the end of this statement.
    //
    // A poisoned lock is recovered rather than treated as "no sender". The guarded value is
    // an `Option<Sender>`, which is always a valid value, so there is no broken invariant to
    // protect. Bailing out would silently drop this task's metrics.
    let tx = metrics_tx
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .take();
    let Some(tx) = tx else { return };

    // One entry per visited node, in the traversal order of `apply_with_dt_ctx` (pre-order,
    // per the field name). Nodes without convertible metrics push a default entry instead of
    // being skipped. NOTE(review): presumably this lets the receiver match entries to nodes
    // by position.
    let mut pre_order_plan_metrics = vec![];
    // Best-effort: if the traversal fails, whatever was collected so far is still sent.
    let _ = plan.apply_with_dt_ctx(dt_ctx, |node, _| {
        pre_order_plan_metrics.push(
            node.metrics()
                .and_then(|m| df_metrics_set_to_proto(&m).ok())
                .unwrap_or_default(),
        );
        Ok(TreeNodeRecursion::Continue)
    });

    // The receiver may already be gone; metrics are best-effort, so that is not an error.
    let _ = tx.send(TaskMetrics {
        pre_order_plan_metrics,
        task_metrics: Some(task_data_metrics.to_proto_metrics_set()),
    });
}
fn garbage_collect_arrays(batch: RecordBatch) -> Result<RecordBatch, DataFusionError> {
let (schema, arrays, row_count) = batch.into_parts();
let arrays = arrays
.into_iter()
.map(|array| {
if let Some(array) = array.as_any_dictionary_opt() {
garbage_collect_any_dictionary(array)
} else if let Some(array) = array.as_string_view_opt() {
Ok(Arc::new(array.gc()) as Arc<dyn Array>)
} else if let Some(array) = array.as_binary_view_opt() {
Ok(Arc::new(array.gc()) as Arc<dyn Array>)
} else {
Ok(array)
}
})
.collect::<Result<Vec<_>, _>>()?;
Ok(RecordBatch::try_new_with_options(
schema,
arrays,
&RecordBatchOptions::new().with_row_count(Some(row_count)),
)?)
}