use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use arrow::{
array::{RecordBatch, StringBuilder, TimestampMillisecondBuilder},
datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
error::ArrowError,
};
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_plan::{
ExecutionPlan, placeholder_row::PlaceholderRowExec, projection::ProjectionExec,
};
use futures::{StreamExt, stream::BoxStream};
use serde::Serialize;
use tokio::{
runtime::Handle,
sync::mpsc::{Receiver, Sender},
task::{AbortHandle, JoinHandle},
};
use crate::{
DistError, DistResult,
network::{StageInfo, TaskSetInfo},
planner::StageId,
};
/// Reports whether `plan` is the physical plan for a bare `SELECT 1`.
///
/// Only a [`ProjectionExec`] that reads from a [`PlaceholderRowExec`] and
/// projects exactly one expression matches, and that expression must be the
/// literal `1` typed as `Int32` or `Int64`.
pub fn is_plan_select_1(plan: &Arc<dyn ExecutionPlan>) -> bool {
    let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() else {
        return false;
    };
    let projected = projection.expr();
    let shape_matches =
        projection.input().as_any().is::<PlaceholderRowExec>() && projected.len() == 1;
    if !shape_matches {
        return false;
    }
    projected[0]
        .expr
        .as_any()
        .downcast_ref::<Literal>()
        .is_some_and(|literal| {
            matches!(
                literal.value(),
                ScalarValue::Int32(Some(1)) | ScalarValue::Int64(Some(1))
            )
        })
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// `Duration::as_millis` yields a `u128`. The value is converted with
/// `i64::try_from` and saturates at `i64::MAX`, instead of an `as` cast that
/// would silently wrap on overflow.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
pub fn timestamp_ms() -> i64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_millis();
    i64::try_from(millis).unwrap_or(i64::MAX)
}
/// Returns this host's local IP address, as reported by
/// `local_ip_address::local_ip`, rendered as a string.
///
/// # Panics
///
/// Panics if the local IP address cannot be determined.
pub fn get_local_ip() -> String {
    let ip = local_ip_address::local_ip().expect("Failed to get local IP");
    ip.to_string()
}
pub struct ReceiverStreamBuilder<O> {
tx: Sender<DistResult<O>>,
rx: Receiver<DistResult<O>>,
task: Option<JoinHandle<DistResult<()>>>,
}
impl<O: Send + 'static> ReceiverStreamBuilder<O> {
pub fn new(capacity: usize) -> Self {
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
Self { tx, rx, task: None }
}
pub fn tx(&self) -> Sender<DistResult<O>> {
self.tx.clone()
}
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = DistResult<()>>,
F: Send + 'static,
{
assert!(
self.task.is_none(),
"ReceiverStreamBuilder supports a single task"
);
let join_handle = handle.spawn(task);
let abort_handle = join_handle.abort_handle();
self.task = Some(join_handle);
abort_handle
}
pub fn build(self) -> BoxStream<'static, DistResult<O>> {
let Self { tx, rx, task } = self;
drop(tx);
let check = async move {
let task = task?;
match task.await {
Ok(Ok(())) => None,
Ok(Err(error)) => Some(Err(error)),
Err(e) => Some(Err(DistError::internal(format!("Tokio join error: {e}")))),
}
};
let check_stream = futures::stream::once(check)
.filter_map(|item| async move { item });
let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
let next_item = rx.recv().await;
next_item.map(|next_item| (next_item, rx))
});
futures::stream::select(rx_stream, check_stream).boxed()
}
}
/// Converts per-stage job bookkeeping into an Arrow [`RecordBatch`] with one
/// row per job.
///
/// Output columns (all non-nullable):
/// - `job_id` (`Utf8`)
/// - `created_at` (`Timestamp(Millisecond, None)`)
/// - `job_meta` (`Utf8`): pretty-printed JSON of the job metadata
/// - `stages` (`Utf8`): pretty-printed JSON object mapping each stage
///   (rendered as a string) to its assigned partitions and task-set infos
#[derive(Debug)]
pub struct JobsArrowConverter {
    schema: SchemaRef,
}

impl Default for JobsArrowConverter {
    fn default() -> Self {
        Self::new()
    }
}

impl JobsArrowConverter {
    /// Creates a converter with the fixed four-column jobs schema.
    pub fn new() -> Self {
        let schema = Arc::new(Schema::new(vec![
            Field::new("job_id", DataType::Utf8, false),
            Field::new(
                "created_at",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("job_meta", DataType::Utf8, false),
            Field::new("stages", DataType::Utf8, false),
        ]));
        Self { schema }
    }

    /// Returns the schema of the batches produced by [`Self::convert`].
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Groups `jobs` by job id and renders one row per job, ordered by job id.
    ///
    /// A job's `created_at` is the earliest `created_at_ms` among its stages,
    /// and each stage's assigned partitions are emitted in ascending order.
    /// Both choices make the output independent of `HashMap` iteration order.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::ExternalError`] if JSON serialization fails, or
    /// the error from [`RecordBatch::try_new`] if the columns don't fit the
    /// schema.
    pub fn convert(&self, jobs: &HashMap<StageId, StageInfo>) -> Result<RecordBatch, ArrowError> {
        use std::collections::BTreeSet;

        #[derive(Serialize)]
        struct StagePayload {
            // Sorted, so the serialized JSON array has a stable order.
            assigned_partitions: BTreeSet<usize>,
            task_set_infos: Vec<TaskSetInfo>,
        }

        /// Copies a partition set into ascending order for stable serialization.
        fn sorted_partitions(partitions: &HashSet<usize>) -> BTreeSet<usize> {
            partitions.iter().copied().collect()
        }

        // job_id -> (earliest created_at_ms, job_meta, stage key -> payload).
        // The outer BTreeMap keeps output rows ordered by job id.
        let mut grouped_jobs = BTreeMap::new();
        for (stage_id, stage_info) in jobs {
            let (created_at_ms, _, stages) = grouped_jobs
                .entry(stage_id.job_id.clone())
                .or_insert_with(|| {
                    (
                        stage_info.created_at_ms,
                        // NOTE(review): job_meta is taken from whichever stage of
                        // the job is visited first; assumed identical across a
                        // job's stages - confirm.
                        stage_info.job_meta.clone(),
                        BTreeMap::<String, StagePayload>::new(),
                    )
                });
            // Keep the earliest stage creation time so the result does not
            // depend on HashMap iteration order.
            *created_at_ms = (*created_at_ms).min(stage_info.created_at_ms);
            // Keys are strings, so stages sort lexicographically
            // (e.g. "10" before "2").
            stages.insert(
                stage_id.stage.to_string(),
                StagePayload {
                    assigned_partitions: sorted_partitions(&stage_info.assigned_partitions),
                    task_set_infos: stage_info.task_set_infos.clone(),
                },
            );
        }

        let mut job_id_builder = StringBuilder::new();
        let mut created_at_builder = TimestampMillisecondBuilder::new();
        let mut job_meta_builder = StringBuilder::new();
        let mut stages_builder = StringBuilder::new();
        for (job_id, (created_at_ms, job_meta, stages)) in grouped_jobs {
            job_id_builder.append_value(job_id.as_ref());
            created_at_builder.append_value(created_at_ms);
            let job_meta_json = serde_json::to_string_pretty(job_meta.as_ref())
                .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
            job_meta_builder.append_value(job_meta_json);
            let stages_json = serde_json::to_string_pretty(&stages)
                .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
            stages_builder.append_value(stages_json);
        }
        RecordBatch::try_new(
            self.schema.clone(),
            vec![
                Arc::new(job_id_builder.finish()),
                Arc::new(created_at_builder.finish()),
                Arc::new(job_meta_builder.finish()),
                Arc::new(stages_builder.finish()),
            ],
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::prelude::SessionContext;

    /// Plans `sql` with a fresh default session and returns the physical plan.
    async fn physical_plan(sql: &str) -> Arc<dyn ExecutionPlan> {
        let ctx = SessionContext::new();
        let df = ctx.sql(sql).await.unwrap();
        df.create_physical_plan().await.unwrap()
    }

    #[tokio::test]
    async fn test_is_plan_select_1() {
        assert!(is_plan_select_1(&physical_plan("SELECT 1").await));
    }

    /// Plans that differ from `SELECT 1` in the literal value or in the number
    /// of projected columns must not be matched.
    #[tokio::test]
    async fn test_is_plan_select_1_rejects_other_plans() {
        assert!(!is_plan_select_1(&physical_plan("SELECT 2").await));
        assert!(!is_plan_select_1(&physical_plan("SELECT 1, 2").await));
    }
}