use crate::common::on_drop_stream;
use crate::metrics::LatencyMetricExt;
use crate::networking::get_distributed_channel_resolver;
use crate::passthrough_headers::get_passthrough_headers;
use crate::protobuf::{datafusion_error_to_tonic_status, map_flight_to_datafusion_error};
use crate::worker::generated::worker::FlightAppMetadata;
use crate::worker::generated::worker::{ExecuteTaskRequest, TaskKey};
use crate::{BytesMetricExt, ChannelResolver, Stage};
use arrow_flight::FlightData;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use datafusion::arrow::array::RecordBatch;
use datafusion::common::instant::Instant;
use datafusion::common::runtime::SpawnedTask;
use datafusion::common::{DataFusionError, Result, internal_err};
use datafusion::execution::TaskContext;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion::physical_expr_common::metrics::{ExecutionPlanMetricsSet, MetricValue};
use datafusion::physical_plan::metrics::{MetricBuilder, Time};
use futures::{Stream, TryStreamExt};
use http::Extensions;
use pin_project::{pin_project, pinned_drop};
use prost::Message;
use std::borrow::Cow;
use std::fmt::{Debug, Formatter};
use std::ops::Range;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::metadata::MetadataMap;
use tonic::{Request, Status};
/// Holds one lazily-initialized [`WorkerConnection`] slot per input task.
///
/// A slot is filled the first time `get_or_init_worker_connection` is called for its task index
/// and keeps whatever that first initialization produced: the connection, or the error wrapped
/// in an `Arc` so that it can be handed out again on later calls.
pub(crate) struct WorkerConnectionPool {
// One slot per input task, indexed by task number.
connections: Vec<OnceLock<Result<WorkerConnection, Arc<DataFusionError>>>>,
// Metrics set handed to every `WorkerConnection::init` call made through this pool.
pub(crate) metrics: ExecutionPlanMetricsSet,
}
impl WorkerConnectionPool {
    /// Creates a pool with `input_tasks` empty connection slots and a fresh metrics set.
    pub(crate) fn new(input_tasks: usize) -> Self {
        let connections = (0..input_tasks).map(|_| OnceLock::new()).collect();
        let metrics = ExecutionPlanMetricsSet::default();
        Self {
            connections,
            metrics,
        }
    }

    /// Returns the connection for `target_task`, initializing it on first use.
    ///
    /// The slot is initialized at most once: the arguments of the first call for a given task
    /// decide how the connection is built, and later calls get the stored outcome back.
    ///
    /// # Errors
    /// - An internal error when `target_task` is not a valid slot index.
    /// - `DataFusionError::Shared` wrapping the stored initialization error when the first
    ///   initialization of this slot failed.
    pub(crate) fn get_or_init_worker_connection(
        &self,
        input_stage: &Stage,
        target_partitions: Range<usize>,
        target_task: usize,
        ctx: &Arc<TaskContext>,
    ) -> Result<&WorkerConnection> {
        let slot = match self.connections.get(target_task) {
            Some(slot) => slot,
            None => {
                return internal_err!(
                    "WorkerConnections: Task index {target_task} not found, only have {} tasks",
                    self.connections.len()
                );
            }
        };

        // The error is stored behind an `Arc` so that every caller can receive a handle to it.
        let outcome = slot.get_or_init(|| {
            let initialized = WorkerConnection::init(
                input_stage,
                target_partitions,
                target_task,
                ctx,
                &self.metrics,
            );
            initialized.map_err(Arc::new)
        });

        outcome
            .as_ref()
            .map_err(|shared| DataFusionError::Shared(Arc::clone(shared)))
    }
}
/// Item sent over the per-partition channels of a [`WorkerConnection`].
///
/// `Ok` carries the raw Flight message, its decoded application metadata and the memory
/// reservation that was grown by the message's encoded size; `Err` carries the `Status` that
/// ended the connection.
type WorkerMsg = Result<(FlightData, FlightAppMetadata, MemoryReservation), Status>;
/// A connection to one worker task whose interleaved Flight stream is split into one channel
/// per target partition by a background task.
pub(crate) struct WorkerConnection {
// Background task that reads the worker's stream and forwards messages to the partition channels.
task: Arc<SpawnedTask<()>>,
// Number of partition streams still outstanding; starts at the number of partition receivers
// and is decremented each time a stream returned by `stream_partition` is dropped.
not_consumed_streams: Arc<AtomicUsize>,
// Cancelled once `not_consumed_streams` reaches zero; the background task stops when it fires.
cancel_token: CancellationToken,
// Receiving ends of the partition channels, keyed by partition; an entry is removed when
// `stream_partition` hands it out.
per_partition_rx: DashMap<usize, UnboundedReceiver<WorkerMsg>>,
// Bytes received by the background task that the partition streams have not consumed yet.
curr_mem_used: Arc<AtomicUsize>,
// Elapsed-compute metric shared by the background task and the partition streams.
elapsed_compute: Time,
}
impl WorkerConnection {
    /// Opens the connection to the worker that runs task `target_task` of `input_stage`.
    ///
    /// This registers the network and memory metrics on `metrics`, builds the
    /// `ExecuteTaskRequest` for the partitions in `target_partition_range`, creates one
    /// unbounded channel per partition and spawns a background task that:
    ///
    /// 1. resolves a client for the task's URL and issues the request,
    /// 2. reads the interleaved response stream until it ends, fails or is cancelled,
    /// 3. routes every message to the channel of the partition named in its Flight metadata,
    ///    together with a memory reservation sized to the encoded message.
    ///
    /// # Errors
    /// Returns an internal error when `target_task` is not present in `input_stage.tasks` or
    /// when that task has no URL assigned. Failures that occur inside the background task are
    /// not returned here; they are broadcast to every partition channel through `fanout`.
    fn init(
        input_stage: &Stage,
        target_partition_range: Range<usize>,
        target_task: usize,
        ctx: &Arc<TaskContext>,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Result<Self> {
        let channel_resolver = get_distributed_channel_resolver(ctx.as_ref());

        // Bytes sitting in the partition channels; the consumers decrement it as they read.
        let curr_mem_used = Arc::new(AtomicUsize::new(0));
        let curr_mem_used_clone = Arc::clone(&curr_mem_used);
        // Highest value of `curr_mem_used` observed by the background task so far.
        let mut curr_max_mem = 0;
        let max_mem_used = MetricBuilder::new(metrics).global_gauge("max_mem_used");
        let bytes_transferred = MetricBuilder::new(metrics).bytes_counter("bytes_transferred");
        let min_latency = MetricBuilder::new(metrics).min_latency("network_latency_min");
        let max_latency = MetricBuilder::new(metrics).max_latency("network_latency_max");
        let p50_latency = MetricBuilder::new(metrics).p50_latency("network_latency_p50");
        let p95_latency = MetricBuilder::new(metrics).p95_latency("network_latency_p95");
        let first_latency = MetricBuilder::new(metrics).first_latency("network_latency_first");
        let sum_latency = Time::new();
        MetricBuilder::new(metrics).build(MetricValue::Time {
            name: Cow::Borrowed("network_latency_sum"),
            time: sum_latency.clone(),
        });
        let latency_count = MetricBuilder::new(metrics).counter("network_latency_count", 0);
        let elapsed_compute = Time::new();
        let elapsed_compute_clone = elapsed_compute.clone();
        MetricBuilder::new(metrics).build(MetricValue::ElapsedCompute(elapsed_compute.clone()));

        let headers = get_passthrough_headers(ctx.session_config());
        let request = Request::from_parts(
            MetadataMap::from_headers(headers),
            Extensions::default(),
            ExecuteTaskRequest {
                target_partition_start: target_partition_range.start as u64,
                target_partition_end: target_partition_range.end as u64,
                task_key: Some(TaskKey {
                    query_id: input_stage.query_id.as_bytes().to_vec(),
                    stage_id: input_stage.num as u64,
                    task_number: target_task as u64,
                }),
            },
        );

        let Some(task) = input_stage.tasks.get(target_task) else {
            return internal_err!("ProgrammingError: Task {target_task} not found");
        };
        let Some(url) = task.url.clone() else {
            return internal_err!("ProgrammingError: task is unassigned, cannot proceed");
        };

        // One channel per target partition. Senders are stored in range order, so the sender of
        // partition `p` lives at index `p - target_partition_range.start`.
        let mut per_partition_tx = Vec::with_capacity(target_partition_range.len());
        let per_partition_rx = DashMap::with_capacity(target_partition_range.len());
        for partition in target_partition_range.clone() {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<WorkerMsg>();
            per_partition_tx.push(tx);
            per_partition_rx.insert(partition, rx);
        }

        let memory_pool = Arc::clone(ctx.memory_pool());
        let cancel_token = CancellationToken::new();
        let cancel = cancel_token.clone();

        let task = SpawnedTask::spawn(
            async move {
                let mut client = match channel_resolver.get_worker_client_for_url(&url).await {
                    Ok(v) => v,
                    Err(err) => {
                        return fanout(&per_partition_tx, datafusion_error_to_tonic_status(&err));
                    }
                };
                let mut interleaved_stream = match client.execute_task(request).await {
                    Ok(v) => v.into_inner(),
                    Err(err) => return fanout(&per_partition_tx, err),
                };
                let consumer = MemoryConsumer::new("WorkerConnection");

                loop {
                    // `biased` makes cancellation win over a message that is ready at the same
                    // time.
                    let msg = tokio::select! {
                        biased;
                        _ = cancel.cancelled() => return,
                        msg = interleaved_stream.next() => {
                            match msg {
                                Some(Ok(v)) => v,
                                Some(Err(err)) => return fanout(&per_partition_tx, err),
                                None => return,
                            }
                        }
                    };
                    let msg_received_time = SystemTime::now();

                    let flight_metadata = match FlightAppMetadata::decode(msg.app_metadata.as_ref())
                    {
                        Ok(v) => v,
                        Err(err) => {
                            return fanout(&per_partition_tx, Status::internal(err.to_string()));
                        }
                    };

                    // Latency is only recorded when the message carries a creation timestamp and
                    // the local clock is not behind it.
                    let sent_time = UNIX_EPOCH
                        + Duration::from_nanos(flight_metadata.created_timestamp_unix_nanos);
                    if flight_metadata.created_timestamp_unix_nanos > 0
                        && let Ok(delta) = msg_received_time.duration_since(sent_time)
                    {
                        min_latency.add_duration(delta);
                        max_latency.add_duration(delta);
                        p50_latency.add_duration(delta);
                        p95_latency.add_duration(delta);
                        first_latency.add_duration(delta);
                        sum_latency.add_duration(delta);
                        latency_count.add(1);
                    }

                    let partition = flight_metadata.partition as usize;
                    // `checked_sub` keeps a partition below the range start on the error path
                    // instead of underflowing (which panics in debug builds).
                    let Some(o_tx) = partition
                        .checked_sub(target_partition_range.start)
                        .and_then(|sender_i| per_partition_tx.get(sender_i))
                    else {
                        let msg = format!(
                            "Received partition {partition} in Flight metadata, but available partitions are {target_partition_range:?}"
                        );
                        return fanout(&per_partition_tx, Status::internal(msg));
                    };

                    let reservation = consumer.clone_with_new_id().register(&memory_pool);
                    let size = msg.encoded_len();
                    bytes_transferred.add_bytes(size);

                    // `fetch_add` returns the value *before* the addition, so `size` is added
                    // back to get the amount buffered once this message is accounted for.
                    let mem_used_now = curr_mem_used
                        .fetch_add(size, Ordering::Relaxed)
                        .saturating_add(size);
                    if mem_used_now > curr_max_mem {
                        curr_max_mem = mem_used_now;
                        max_mem_used.set(curr_max_mem);
                    }
                    reservation.grow(size);

                    // NOTE(review): a failed send means this partition's receiver was dropped.
                    // Returning here ends the task, which also closes the channels of every
                    // other partition — confirm this is intended when only one partition stream
                    // is dropped early.
                    if o_tx.send(Ok((msg, flight_metadata, reservation))).is_err() {
                        return;
                    };
                }
            }
            .with_elapsed_compute(elapsed_compute),
        );

        Ok(Self {
            task: Arc::new(task),
            cancel_token,
            not_consumed_streams: Arc::new(AtomicUsize::new(per_partition_rx.len())),
            per_partition_rx,
            curr_mem_used: curr_mem_used_clone,
            elapsed_compute: elapsed_compute_clone,
        })
    }

    /// Takes the channel of `partition` and returns it as a stream of decoded record batches.
    ///
    /// `on_metadata` is invoked with the Flight application metadata of every message before
    /// the message is passed to the decoder. Each partition can be streamed only once, because
    /// its receiver is removed from the connection by this call.
    ///
    /// When the returned stream is dropped the count of outstanding streams is decremented, and
    /// the drop that brings it to zero cancels the background task.
    ///
    /// # Errors
    /// Returns an internal error when there is no receiver for `partition`, either because it
    /// is outside the range this connection was created for or because it was already taken.
    pub(crate) fn stream_partition(
        &self,
        partition: usize,
        on_metadata: impl Fn(FlightAppMetadata) + Send + Sync + 'static,
    ) -> Result<impl Stream<Item = Result<RecordBatch>> + 'static> {
        let Some((_, partition_receiver)) = self.per_partition_rx.remove(&partition) else {
            return internal_err!(
                "WorkerConnection has no stream for target partition {partition}. Was it already consumed?"
            );
        };

        let task = Arc::clone(&self.task);
        let cancel_token = self.cancel_token.clone();
        let curr_mem_used = Arc::clone(&self.curr_mem_used);

        let stream = UnboundedReceiverStream::new(partition_receiver);
        let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err)));
        let stream = stream.map_ok(move |(data, meta, reservation)| {
            // The message leaves the buffer here: undo its accounting and release the memory.
            curr_mem_used.fetch_sub(reservation.size(), Ordering::Relaxed);
            drop(reservation);
            // Referencing `task` moves the `Arc` into this closure, so the stream holds a handle
            // to the background task for as long as it lives.
            let _ = &task;
            on_metadata(meta);
            data
        });
        let stream = FlightRecordBatchStream::new_from_flight_data(stream);
        let stream = stream.map_err(map_flight_to_datafusion_error);
        let stream = stream.with_elapsed_compute(self.elapsed_compute.clone());

        let not_consumed_streams = Arc::clone(&self.not_consumed_streams);
        Ok(on_drop_stream(stream, move || {
            // `fetch_sub` returns the previous value; subtract one to get the remaining count.
            let remaining_streams = not_consumed_streams.fetch_sub(1, Ordering::SeqCst) - 1;
            if remaining_streams == 0 {
                cancel_token.cancel();
            }
        }))
    }
}
/// Sends a copy of `err` to every partition channel in `o_txs`.
///
/// Send failures are ignored: a channel whose receiver is gone has nobody left to notify.
fn fanout(o_txs: &[UnboundedSender<WorkerMsg>], err: Status) {
    o_txs.iter().for_each(|sender| {
        let _ = sender.send(Err(err.clone()));
    });
}
impl Debug for WorkerConnectionPool {
    /// Prints only the number of connection slots; the connections themselves are not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let num_connections = self.connections.len();
        let mut out = f.debug_struct("WorkerConnections");
        out.field("num_connections", &num_connections);
        out.finish()
    }
}
impl Clone for WorkerConnectionPool {
    /// Produces an empty pool with the same number of slots.
    ///
    /// Neither the initialized connections nor the metrics are carried over to the clone.
    fn clone(&self) -> Self {
        let input_tasks = self.connections.len();
        Self::new(input_tasks)
    }
}
impl Debug for WorkerConnection {
    /// Prints the type name only; no field is included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WorkerConnection");
        out.finish()
    }
}
/// Extension trait that adds `with_elapsed_compute` to futures.
trait ElapsedComputeFutureExt: Future + Sized {
/// Wraps `self` in an [`ElapsedComputeFuture`] that reports into `elapsed_compute`.
fn with_elapsed_compute(self, elapsed_compute: Time) -> ElapsedComputeFuture<Self>;
}
/// Extension trait that adds `with_elapsed_compute` to streams.
trait ElapsedComputeStreamExt: Stream + Sized {
/// Wraps `self` in an [`ElapsedComputeStream`] that reports into `elapsed_compute`.
fn with_elapsed_compute(self, elapsed_compute: Time) -> ElapsedComputeStream<Self>;
}
/// Blanket implementation: every future can be wrapped with `with_elapsed_compute`.
impl<F> ElapsedComputeFutureExt for F
where
    F: Future,
{
    fn with_elapsed_compute(self, elapsed_compute: Time) -> ElapsedComputeFuture<Self> {
        // Start with nothing accumulated; time is added on every poll of the wrapper.
        let curr = Duration::ZERO;
        ElapsedComputeFuture {
            inner: self,
            curr,
            elapsed_compute,
        }
    }
}
/// Blanket implementation: every stream can be wrapped with `with_elapsed_compute`.
impl<S> ElapsedComputeStreamExt for S
where
    S: Stream,
{
    fn with_elapsed_compute(self, elapsed_compute: Time) -> ElapsedComputeStream<Self> {
        // Start with nothing accumulated; time is added on every poll of the wrapper.
        let curr = Duration::ZERO;
        ElapsedComputeStream {
            inner: self,
            curr,
            elapsed_compute,
        }
    }
}
/// Stream wrapper that measures the time spent inside `inner.poll_next` and adds it to
/// `elapsed_compute`.
///
/// Time is accumulated in `curr` across polls and flushed to the metric whenever a poll
/// returns `Poll::Ready`, or on drop if something is still pending.
#[pin_project(PinnedDrop)]
struct ElapsedComputeStream<T> {
// The wrapped stream.
#[pin]
inner: T,
// Time measured since the last flush that has not been added to `elapsed_compute` yet.
curr: Duration,
// Metric that receives the measured time.
elapsed_compute: Time,
}
#[pinned_drop]
impl<T> PinnedDrop for ElapsedComputeStream<T> {
fn drop(self: Pin<&mut Self>) {
if self.curr > Duration::default() {
let self_projected = self.project();
self_projected
.elapsed_compute
.add_duration(*self_projected.curr);
}
}
}
impl<O, F: Stream<Item = O>> Stream for ElapsedComputeStream<F> {
type Item = O;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let self_projected = self.project();
let start = Instant::now();
let result = self_projected.inner.poll_next(cx);
*self_projected.curr += start.elapsed();
if result.is_ready() {
self_projected
.elapsed_compute
.add_duration(*self_projected.curr);
*self_projected.curr = Duration::default();
}
result
}
}
/// Future wrapper that measures the time spent inside `inner.poll` and adds it to
/// `elapsed_compute`.
///
/// Time is accumulated in `curr` across polls and flushed to the metric when a poll returns
/// `Poll::Ready`, or on drop if something is still pending.
#[pin_project(PinnedDrop)]
struct ElapsedComputeFuture<T> {
// The wrapped future.
#[pin]
inner: T,
// Time measured since the last flush that has not been added to `elapsed_compute` yet.
curr: Duration,
// Metric that receives the measured time.
elapsed_compute: Time,
}
#[pinned_drop]
impl<T> PinnedDrop for ElapsedComputeFuture<T> {
fn drop(self: Pin<&mut Self>) {
if self.curr > Duration::default() {
let self_projected = self.project();
self_projected
.elapsed_compute
.add_duration(*self_projected.curr);
}
}
}
impl<O, F: Future<Output = O>> Future for ElapsedComputeFuture<F> {
type Output = O;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let self_projected = self.project();
let start = Instant::now();
let result = self_projected.inner.poll(cx);
*self_projected.curr += start.elapsed();
if result.is_ready() {
self_projected
.elapsed_compute
.add_duration(*self_projected.curr);
*self_projected.curr = Duration::default();
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use futures::stream::unfold;

    /// A future that mostly sleeps must report less elapsed compute than one that is polled
    /// many times and does arithmetic on every poll.
    #[tokio::test]
    async fn elapsed_compute_future() {
        // Spends its time waiting on a timer, so very little time is spent inside `poll`.
        async fn cheap() {
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
        // Yields on every iteration, forcing 100000 measured polls.
        async fn expensive() {
            let mut _count = 0f64;
            // Starts at 1 so the first division is not `0.0 / 0.0`.
            for i in 1..=100000 {
                tokio::task::yield_now().await;
                _count /= i as f64
            }
        }

        let cheap_time = Time::new();
        cheap().with_elapsed_compute(cheap_time.clone()).await;
        println!("cheap future: {}", cheap_time.value());

        let expensive_time = Time::new();
        expensive()
            .with_elapsed_compute(expensive_time.clone())
            .await;
        println!("expensive future: {}", expensive_time.value());

        assert!(expensive_time.value() > cheap_time.value());
    }

    /// A stream whose items only wait on a timer must report less elapsed compute than one
    /// that runs a CPU-bound loop for every item.
    #[tokio::test]
    async fn elapsed_compute_stream() {
        // Ten items, each produced after a short sleep.
        fn cheap() -> impl Stream<Item = i64> {
            unfold(0i64, |state| async move {
                if state < 10 {
                    tokio::time::sleep(Duration::from_micros(10)).await;
                    Some((state, state + 1))
                } else {
                    None
                }
            })
        }
        // Ten items, each produced after a CPU-bound loop.
        fn expensive() -> impl Stream<Item = i64> {
            unfold(0i64, |state| async move {
                if state < 10 {
                    let mut _count = 0f64;
                    for i in 1..100000 {
                        _count += (i as f64).sqrt();
                    }
                    tokio::task::yield_now().await;
                    Some((state, state + 1))
                } else {
                    None
                }
            })
        }

        let cheap_time = Time::new();
        cheap()
            .with_elapsed_compute(cheap_time.clone())
            .collect::<Vec<_>>()
            .await;
        println!("cheap stream: {}", cheap_time.value());

        let expensive_time = Time::new();
        expensive()
            .with_elapsed_compute(expensive_time.clone())
            .collect::<Vec<_>>()
            .await;
        println!("expensive stream: {}", expensive_time.value());

        assert!(expensive_time.value() > cheap_time.value());
    }
}