use crate::common::{TreeNodeExt, now_ns, serialize_uuid, task_ctx_with_extension};
use crate::config_extension_ext::get_config_extension_propagation_headers;
use crate::coordinator::MetricsStore;
use crate::execution_plans::{ChildrenIsolatorUnionExec, DistributedLeafExec};
use crate::passthrough_headers::get_passthrough_headers;
use crate::protobuf::tonic_status_to_datafusion_error;
use crate::stage::LocalStage;
use crate::work_unit_feed::{build_work_unit_batch_msg, set_work_unit_send_time};
use crate::worker::generated::worker as pb;
use crate::worker::generated::worker::coordinator_to_worker_msg::Inner;
use crate::worker::generated::worker::set_plan_request::WorkUnitFeedDeclaration;
use crate::{
DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, DistributedCodec, DistributedConfig,
DistributedTaskContext, DistributedWorkUnitFeedContext, TaskKey,
get_distributed_channel_resolver,
};
use datafusion::common::Result;
use datafusion::common::instant::Instant;
use datafusion::common::runtime::JoinSet;
use datafusion::common::tree_node::{Transformed, TreeNodeRecursion};
use datafusion::common::{DataFusionError, exec_datafusion_err};
use datafusion::execution::TaskContext;
use datafusion::physical_expr_common::metrics::{
Count, ExecutionPlanMetricsSet, Label, MetricBuilder, MetricValue, Time,
};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::PhysicalPlanNode;
use futures::StreamExt;
use http::Extensions;
use prost::Message;
use std::fmt::Display;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Request;
use tonic::metadata::MetadataMap;
use url::Url;
use uuid::Uuid;
const WORK_UNIT_FEED_CHUNK_SIZE: usize = 256;
#[derive(Clone)]
pub(super) struct CoordinatorToWorkerMetrics {
pub(super) plan_bytes_sent: Count,
pub(super) plan_send_latency: Arc<LatencyMetric>,
pub(super) instantiation_time: u64,
}
impl CoordinatorToWorkerMetrics {
pub(super) fn new(metrics: &ExecutionPlanMetricsSet) -> Self {
Self {
plan_bytes_sent: MetricBuilder::new(metrics)
.with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0"))
.global_counter("plan_bytes_sent"),
plan_send_latency: Arc::new(LatencyMetric::new(
"plan_send_latency",
|b| b.with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0")),
metrics,
)),
instantiation_time: now_ns(),
}
}
}
pub(super) struct CoordinatorToWorkerTaskSpawner<'a> {
plan: &'a Arc<dyn ExecutionPlan>,
query_id: Uuid,
stage_id: usize,
task_count: usize,
task_ctx: &'a TaskContext,
metrics: &'a CoordinatorToWorkerMetrics,
task_metrics: &'a Option<Arc<MetricsStore>>,
join_set: &'a mut JoinSet<Result<()>>,
}
impl<'a> CoordinatorToWorkerTaskSpawner<'a> {
pub(super) fn new(
stage: &'a LocalStage,
metrics: &'a CoordinatorToWorkerMetrics,
task_metrics: &'a Option<Arc<MetricsStore>>,
task_ctx: &'a TaskContext,
join_set: &'a mut JoinSet<Result<()>>,
) -> Result<Self> {
Ok(Self {
plan: &stage.plan,
query_id: stage.query_id,
stage_id: stage.num,
task_count: stage.tasks,
task_ctx,
metrics,
task_metrics,
join_set,
})
}
pub(super) fn send_plan_task(
&mut self,
ctx: Arc<TaskContext>,
task_i: usize,
url: Url,
) -> Result<(
UnboundedSender<pb::CoordinatorToWorkerMsg>,
UnboundedReceiver<pb::WorkerToCoordinatorMsg>,
)> {
let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?;
let wuf_registry = &d_cfg.__private_work_unit_feed_registry;
let mut work_unit_feed_declarations = vec![];
let d_ctx = DistributedTaskContext {
task_index: task_i,
task_count: self.task_count,
};
let plan = Arc::clone(self.plan);
let specialized = plan.transform_down_with_dt_ctx(d_ctx, |plan, d_ctx| {
if let Some(wuf) = wuf_registry.get_work_unit_feed(&plan) {
work_unit_feed_declarations.push(WorkUnitFeedDeclaration {
id: serialize_uuid(&wuf.id()),
partitions: plan.properties().partitioning.partition_count() as u64,
});
};
if let Some(ciu) = plan.downcast_ref::<ChildrenIsolatorUnionExec>() {
let ciu = ciu.to_task_specialized(d_ctx.task_index);
return Ok(Transformed::yes(Arc::new(ciu)));
};
if let Some(dle) = plan.downcast_ref::<DistributedLeafExec>() {
let specialized = dle.to_task_specialized(d_ctx.task_index);
return Ok(Transformed::yes(specialized));
}
Ok(Transformed::no(plan))
})?;
let codec = DistributedCodec::new_combined_with_user(self.task_ctx.session_config());
let plan_proto =
PhysicalPlanNode::try_from_physical_plan(specialized.data, &codec)?.encode_to_vec();
let plan_size = plan_proto.len();
let task_key = TaskKey {
query_id: serialize_uuid(&self.query_id),
stage_id: self.stage_id as u64,
task_number: task_i as u64,
};
let msg = pb::CoordinatorToWorkerMsg {
inner: Some(Inner::SetPlanRequest(pb::SetPlanRequest {
plan_proto,
task_count: self.task_count as u64,
task_key: Some(task_key.clone()),
work_unit_feed_declarations,
target_worker_url: url.to_string(),
query_start_time_ns: self.metrics.instantiation_time,
})),
};
let (coordinator_to_worker_tx, coordinator_to_worker_rx) =
tokio::sync::mpsc::unbounded_channel();
let (worker_to_coordinator_tx, worker_to_coordinator_rx) =
tokio::sync::mpsc::unbounded_channel();
let channel_resolver = get_distributed_channel_resolver(ctx.as_ref());
let mut headers = get_config_extension_propagation_headers(ctx.session_config())?;
headers.extend(get_passthrough_headers(ctx.session_config()));
let request = Request::from_parts(
MetadataMap::from_headers(headers),
Extensions::default(),
futures::stream::once(async { msg }).chain(
UnboundedReceiverStream::new(coordinator_to_worker_rx).map(set_work_unit_send_time),
),
);
let metrics = self.metrics.clone();
self.join_set.spawn(async move {
let start = Instant::now();
let mut client = channel_resolver.get_worker_client_for_url(&url).await?;
let response = client.coordinator_channel(request).await.map_err(|e| {
tonic_status_to_datafusion_error(&e).unwrap_or_else(|| {
exec_datafusion_err!("Error sending plan to worker {url}: {e}")
})
})?;
metrics.plan_send_latency.record(&start);
metrics.plan_bytes_sent.add(plan_size);
let mut worker_to_coordinator_stream = response.into_inner();
while let Some(msg_or_err) = worker_to_coordinator_stream.next().await {
let msg = match msg_or_err {
Ok(msg) => msg,
Err(err) => {
return Err(tonic_status_to_datafusion_error(err).unwrap_or_else(|| {
exec_datafusion_err!("Unknown error on worker to coordinator stream")
}));
}
};
if worker_to_coordinator_tx.send(msg).is_err() {
break; }
}
Ok::<_, DataFusionError>(())
});
Ok((coordinator_to_worker_tx, worker_to_coordinator_rx))
}
pub(super) fn metrics_collection_task(
&mut self,
task_i: usize,
mut worker_to_coordinator_rx: UnboundedReceiver<pb::WorkerToCoordinatorMsg>,
) {
let task_key = TaskKey {
query_id: serialize_uuid(&self.query_id),
stage_id: self.stage_id as u64,
task_number: task_i as u64,
};
let task_metrics = self.task_metrics.clone();
#[allow(clippy::disallowed_methods)]
tokio::spawn(async move {
while let Some(msg) = worker_to_coordinator_rx.recv().await {
let Some(inner) = msg.inner else { continue };
match inner {
pb::worker_to_coordinator_msg::Inner::TaskMetrics(pre_order_metrics) => {
if let Some(task_metrics) = &task_metrics {
task_metrics.insert(task_key.clone(), pre_order_metrics);
}
}
}
}
});
}
pub(super) fn work_unit_feed_task(
&mut self,
ctx: Arc<TaskContext>,
task_i: usize,
tx: UnboundedSender<pb::CoordinatorToWorkerMsg>,
) -> Result<()> {
let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?;
let wuf_registry = &d_cfg.__private_work_unit_feed_registry;
let d_ctx = DistributedTaskContext {
task_index: task_i,
task_count: self.task_count,
};
let mut futures = vec![];
self.plan.apply_with_dt_ctx(d_ctx, |plan, d_ctx| {
let Some(wuf) = wuf_registry.get_work_unit_feed(plan) else {
return Ok(TreeNodeRecursion::Continue);
};
let partitions = plan.properties().partitioning.partition_count();
let start_partition = partitions * d_ctx.task_index;
let end_partition = start_partition + partitions;
let dist_feed_ctx = DistributedWorkUnitFeedContext {
fan_out_tasks: d_ctx.task_count,
};
let t_ctx = Arc::new(task_ctx_with_extension(&ctx, dist_feed_ctx));
let mut feeds = Vec::with_capacity(end_partition - start_partition);
for (partition, feed_idx) in (start_partition..end_partition).enumerate() {
let feed = wuf
.feed(feed_idx, Arc::clone(&t_ctx))?
.map(move |el| (partition, el));
feeds.push(feed);
}
let interleaved_feed = futures::stream::select_all(feeds);
let mut chunked_interleaved_feed =
interleaved_feed.ready_chunks(WORK_UNIT_FEED_CHUNK_SIZE);
let id = wuf.id();
let tx = tx.clone();
futures.push(Box::pin(async move {
while let Some(chunk) = chunked_interleaved_feed.next().await {
if tx.send(build_work_unit_batch_msg(&id, chunk)?).is_err() {
break; };
}
Ok::<_, DataFusionError>(())
}));
Ok(TreeNodeRecursion::Continue)
})?;
self.join_set.spawn(async move {
futures::future::try_join_all(futures).await?;
Ok(())
});
Ok(())
}
}
pub(super) struct LatencyMetric {
max: Time,
avg: Time,
max_latency_micros: AtomicU64,
sum_latency_micros: AtomicU64,
count_latency_micros: AtomicU64,
}
impl Drop for LatencyMetric {
fn drop(&mut self) {
self.max.add_duration(Duration::from_micros(
self.max_latency_micros.load(Ordering::Relaxed),
));
self.avg.add_duration(Duration::from_micros(
self.sum_latency_micros.load(Ordering::Relaxed)
/ self.count_latency_micros.load(Ordering::Relaxed).max(1),
));
}
}
impl LatencyMetric {
pub(super) fn new(
name: impl Display,
builder: impl Fn(MetricBuilder) -> MetricBuilder,
metrics: &ExecutionPlanMetricsSet,
) -> Self {
let max = Time::new();
builder(MetricBuilder::new(metrics)).build(MetricValue::Time {
name: format!("{name}_max").into(),
time: max.clone(),
});
let avg = Time::new();
builder(MetricBuilder::new(metrics)).build(MetricValue::Time {
name: format!("{name}_avg").into(),
time: avg.clone(),
});
Self {
max,
avg,
max_latency_micros: AtomicU64::new(0),
sum_latency_micros: AtomicU64::new(0),
count_latency_micros: AtomicU64::new(0),
}
}
fn record(&self, start: &Instant) {
let micros = start.elapsed().as_micros() as u64;
self.max_latency_micros.fetch_max(micros, Ordering::Relaxed);
self.sum_latency_micros.fetch_add(micros, Ordering::Relaxed);
self.count_latency_micros.fetch_add(1, Ordering::Relaxed);
}
}