use crate::common::{now_ns, serialize_uuid};
use crate::worker::generated::worker as pb;
use crate::{BytesMetricExt, LatencyMetricExt, WorkUnit};
use datafusion::common::{HashMap, Result, exec_err};
use datafusion::execution::TaskContext;
use datafusion::physical_expr_common::metrics::MetricBuilder;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion_proto::protobuf::proto_error;
use futures::StreamExt;
use futures::stream::BoxStream;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
/// Sending half of an unbounded tokio mpsc channel whose items are `Result<pb::WorkUnit>`.
pub(crate) type WorkUnitTx = UnboundedSender<Result<pb::WorkUnit>>;
/// Receiving half of an unbounded tokio mpsc channel whose items are `Result<pb::WorkUnit>`.
pub(crate) type WorkUnitRx = UnboundedReceiver<Result<pb::WorkUnit>>;
/// Receivers keyed by `(feed id, partition)`.
///
/// Each receiver sits in a `Mutex<Option<_>>` so that it can be moved out exactly once through
/// a shared reference to the map; after that the slot holds `None`.
pub(crate) type RemoteWorkUnitFeedRxs = HashMap<(Uuid, usize), Mutex<Option<WorkUnitRx>>>;
/// Senders keyed by `(feed id, partition)`, using the same keys as [`RemoteWorkUnitFeedRxs`].
pub(crate) type RemoteWorkUnitFeedTxs = HashMap<(Uuid, usize), WorkUnitTx>;
/// Holds both halves of the per-partition work-unit channels created by
/// [`RemoteWorkUnitFeedRegistry::add`].
///
/// The `Default` value is an empty registry with no channels.
#[derive(Default)]
pub(crate) struct RemoteWorkUnitFeedRegistry {
/// Receiving halves, keyed by `(feed id, partition)`.
pub(crate) receivers: RemoteWorkUnitFeedRxs,
/// Sending halves, keyed by `(feed id, partition)`.
pub(crate) senders: RemoteWorkUnitFeedTxs,
}
impl RemoteWorkUnitFeedRegistry {
    /// Creates one unbounded channel for each partition in `0..partitions` of the feed `id`.
    ///
    /// The receiving half is stored in `self.receivers` and the sending half in `self.senders`,
    /// both under the key `(id, partition)`. Entries already present under a key are replaced.
    pub(crate) fn add(&mut self, id: Uuid, partitions: usize) {
        (0..partitions).for_each(|partition| {
            let key = (id, partition);
            let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
            self.receivers.insert(key, Mutex::new(Some(receiver)));
            self.senders.insert(key, sender);
        });
    }
}
/// Wraps a batch of work units into a `pb::CoordinatorToWorkerMsg` carrying a
/// `pb::WorkUnitBatch`.
///
/// Every `(partition, work_unit)` pair becomes one `pb::WorkUnit` tagged with the serialized
/// `id`, the partition and the encoded body. The creation timestamp is taken from `now_ns()`
/// when the unit is built; the sent, received and processed timestamps start at `0`.
///
/// # Errors
///
/// Returns the first `Err` found in `work_unit_batch`; no message is produced in that case.
pub(crate) fn build_work_unit_batch_msg(
    id: &Uuid,
    work_unit_batch: Vec<(usize, Result<Box<dyn WorkUnit>>)>,
) -> Result<pb::CoordinatorToWorkerMsg> {
    let mut batch = Vec::with_capacity(work_unit_batch.len());
    for (partition, work_unit) in work_unit_batch {
        batch.push(pb::WorkUnit {
            id: serialize_uuid(id),
            partition: partition as u64,
            // `?` bails out on the first failed work unit, before its timestamp is taken.
            body: work_unit?.encode_to_bytes(),
            created_timestamp_unix_nanos: now_ns(),
            sent_timestamp_unix_nanos: 0,
            received_timestamp_unix_nanos: 0,
            processed_timestamp_unix_nanos: 0,
        });
    }

    let inner = pb::coordinator_to_worker_msg::Inner::WorkUnitBatch(pb::WorkUnitBatch { batch });
    Ok(pb::CoordinatorToWorkerMsg { inner: Some(inner) })
}
/// Stamps `sent_timestamp_unix_nanos` with `now_ns()` on every work unit of `msg`.
///
/// Messages whose `inner` is not a `WorkUnitBatch` are returned untouched.
pub(crate) fn set_work_unit_send_time(
    mut msg: pb::CoordinatorToWorkerMsg,
) -> pb::CoordinatorToWorkerMsg {
    if let Some(pb::coordinator_to_worker_msg::Inner::WorkUnitBatch(units)) = msg.inner.as_mut() {
        units
            .batch
            .iter_mut()
            .for_each(|unit| unit.sent_timestamp_unix_nanos = now_ns());
    }
    msg
}
/// Stamps `received_timestamp_unix_nanos` with `now_ns()` on every work unit of `msg`.
///
/// Messages whose `inner` is not a `WorkUnitBatch` are returned untouched.
pub(crate) fn set_work_unit_received_time(
    mut msg: pb::CoordinatorToWorkerMsg,
) -> pb::CoordinatorToWorkerMsg {
    if let Some(pb::coordinator_to_worker_msg::Inner::WorkUnitBatch(units)) = msg.inner.as_mut() {
        units
            .batch
            .iter_mut()
            .for_each(|unit| unit.received_timestamp_unix_nanos = now_ns());
    }
    msg
}
/// Hands out the receiving side of a remote work-unit feed as a stream of decoded work units.
#[derive(Debug, Clone)]
pub(crate) struct RemoteFeedProvider {
/// Feed identifier; combined with a partition index to look up the receiver.
pub(crate) id: Uuid,
/// Metrics set in which the work-unit counters and latencies are registered.
pub(crate) metrics: ExecutionPlanMetricsSet,
}
impl RemoteFeedProvider {
    /// Takes the work-unit receiver registered for `(self.id, partition)` and returns it as a
    /// stream of decoded `T` values.
    ///
    /// The receiver is looked up in the `RemoteWorkUnitFeedRxs` extension of the session config
    /// and moved out of its slot, so it can be obtained only once per id and partition.
    ///
    /// For every work unit that arrives, the stream records in `self.metrics` the payload size,
    /// a message count, the time spent decoding, and the max/p50 latency between the unit's
    /// creation timestamp and its sent, received and processed timestamps. These are recorded
    /// even when decoding the body fails.
    ///
    /// # Errors
    ///
    /// Returns an execution error when:
    /// - the `RemoteWorkUnitFeedRxs` extension is missing from the session config,
    /// - no feed is registered for this id and partition,
    /// - the feed for this id and partition was already taken by an earlier call.
    ///
    /// Errors that arrive through the channel and body decoding failures are not returned
    /// here; they are yielded as `Err` items of the stream.
    pub(crate) fn feed<T: WorkUnit + Default>(
        &self,
        partition: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<BoxStream<'static, Result<T>>> {
        let bdr = || MetricBuilder::new(&self.metrics);
        let bytes_transferred = bdr().bytes_counter("work_unit_bytes");
        let msg_count = bdr().global_counter("work_unit_count");
        let send_latency_max = bdr().max_latency("work_unit_send_latency_max");
        let send_latency_p50 = bdr().p50_latency("work_unit_send_latency_p50");
        let received_latency_max = bdr().max_latency("work_unit_received_latency_max");
        let received_latency_p50 = bdr().p50_latency("work_unit_received_latency_p50");
        let processed_latency_max = bdr().max_latency("work_unit_processed_latency_max");
        let processed_latency_p50 = bdr().p50_latency("work_unit_processed_latency_p50");
        let elapsed_compute = bdr().elapsed_compute(partition);

        let Some(rxs) = ctx
            .session_config()
            .get_extension::<RemoteWorkUnitFeedRxs>()
        else {
            return exec_err!("Missing RemoteWorkUnitFeedRegistry in context");
        };

        let id = self.id;
        let Some(remote_feed) = rxs.get(&(id, partition)) else {
            return exec_err!(
                "Missing WorkUnit feed for id {id} and partition {partition}. Was the WorkUnitFeed registered with DistributedExt::with_distributed_work_unit_feed?"
            );
        };

        // The mutex only guards an `Option` that is swapped for `None`, so a poisoned lock
        // cannot hold inconsistent data: recover the guard instead of panicking.
        let receiver = remote_feed
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take();
        let Some(receiver) = receiver else {
            return exec_err!(
                "WorkUnit feed for id {id} and partition {partition} was already consumed"
            );
        };

        Ok(UnboundedReceiverStream::new(receiver)
            .map(move |work_unit_or_err| {
                let work_unit = work_unit_or_err?;

                // Only the decoding itself is accounted as compute time.
                let timer = elapsed_compute.timer();
                let result = T::decode(work_unit.body.as_slice())
                    .map_err(|err| proto_error(format!("{err}")));
                timer.done();
                let processed_timestamp_unix_nanos = now_ns();

                let pb::WorkUnit {
                    created_timestamp_unix_nanos: base,
                    sent_timestamp_unix_nanos,
                    received_timestamp_unix_nanos,
                    body,
                    ..
                } = work_unit;

                // Latencies are measured from the creation timestamp. A timestamp can be
                // earlier than `base` (for example `sent_timestamp_unix_nanos` is still 0 when
                // it was never stamped), and a raw subtraction would then overflow: clamp such
                // deltas to 0 instead of panicking or recording a wrapped-around value.
                let [send_latency, received_latency, processed_latency] = [
                    sent_timestamp_unix_nanos,
                    received_timestamp_unix_nanos,
                    processed_timestamp_unix_nanos,
                ]
                .map(|timestamp| {
                    if timestamp > base {
                        (timestamp - base) as usize
                    } else {
                        0
                    }
                });

                bytes_transferred.add_bytes(body.len());
                msg_count.add(1);
                send_latency_max.add_nanos(send_latency);
                send_latency_p50.add_nanos(send_latency);
                received_latency_max.add_nanos(received_latency);
                received_latency_p50.add_nanos(received_latency);
                processed_latency_max.add_nanos(processed_latency);
                processed_latency_p50.add_nanos(processed_latency);

                result
            })
            .boxed())
    }

    /// Returns a clone of the metrics set used by this provider.
    pub(crate) fn metrics(&self) -> ExecutionPlanMetricsSet {
        self.metrics.clone()
    }
}