1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 sync::Arc,
4 time::{SystemTime, UNIX_EPOCH},
5};
6
7use arrow::{
8 array::{RecordBatch, StringBuilder, TimestampMillisecondBuilder},
9 datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
10 error::ArrowError,
11};
12use datafusion_common::ScalarValue;
13use datafusion_physical_expr::expressions::Literal;
14use datafusion_physical_plan::{
15 ExecutionPlan, placeholder_row::PlaceholderRowExec, projection::ProjectionExec,
16};
17use futures::{StreamExt, stream::BoxStream};
18use serde::Serialize;
19use tokio::{
20 runtime::Handle,
21 sync::mpsc::{Receiver, Sender},
22 task::{AbortHandle, JoinHandle},
23};
24
25use crate::{
26 DistError, DistResult,
27 network::{StageInfo, TaskSetInfo},
28 planner::StageId,
29};
30
31pub fn is_plan_select_1(plan: &Arc<dyn ExecutionPlan>) -> bool {
34 let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() else {
35 return false;
36 };
37 if !proj.input().as_any().is::<PlaceholderRowExec>() {
38 return false;
39 }
40 if proj.expr().len() != 1 {
41 return false;
42 }
43 let expr = &proj.expr()[0];
44 let Some(literal) = expr.expr.as_any().downcast_ref::<Literal>() else {
45 return false;
46 };
47 matches!(
48 literal.value(),
49 ScalarValue::Int32(Some(1)) | ScalarValue::Int64(Some(1))
50 )
51}
52
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reports a time
/// earlier than the Unix epoch.
pub fn timestamp_ms() -> i64 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_millis() as i64
}
59
60pub fn get_local_ip() -> String {
62 local_ip_address::local_ip()
63 .expect("Failed to get local IP")
64 .to_string()
65}
66
67pub struct ReceiverStreamBuilder<O> {
68 tx: Sender<DistResult<O>>,
69 rx: Receiver<DistResult<O>>,
70 task: Option<JoinHandle<DistResult<()>>>,
71}
72
73impl<O: Send + 'static> ReceiverStreamBuilder<O> {
74 pub fn new(capacity: usize) -> Self {
76 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
77
78 Self { tx, rx, task: None }
79 }
80
81 pub fn tx(&self) -> Sender<DistResult<O>> {
83 self.tx.clone()
84 }
85
86 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
88 where
89 F: Future<Output = DistResult<()>>,
90 F: Send + 'static,
91 {
92 assert!(
93 self.task.is_none(),
94 "ReceiverStreamBuilder supports a single task"
95 );
96 let join_handle = handle.spawn(task);
97 let abort_handle = join_handle.abort_handle();
98 self.task = Some(join_handle);
99 abort_handle
100 }
101
102 pub fn build(self) -> BoxStream<'static, DistResult<O>> {
104 let Self { tx, rx, task } = self;
105
106 drop(tx);
108
109 let check = async move {
111 let task = task?;
112
113 match task.await {
114 Ok(Ok(())) => None,
115 Ok(Err(error)) => Some(Err(error)),
116 Err(e) => Some(Err(DistError::internal(format!("Tokio join error: {e}")))),
117 }
118 };
119
120 let check_stream = futures::stream::once(check)
121 .filter_map(|item| async move { item });
123
124 let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
126 let next_item = rx.recv().await;
127 next_item.map(|next_item| (next_item, rx))
128 });
129
130 futures::stream::select(rx_stream, check_stream).boxed()
133 }
134}
135
136#[derive(Debug)]
137pub struct JobsArrowConverter {
138 schema: SchemaRef,
139}
140
141impl Default for JobsArrowConverter {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl JobsArrowConverter {
148 pub fn new() -> Self {
149 let schema = Arc::new(Schema::new(vec![
150 Field::new("job_id", DataType::Utf8, false),
151 Field::new(
152 "created_at",
153 DataType::Timestamp(TimeUnit::Millisecond, None),
154 false,
155 ),
156 Field::new("job_meta", DataType::Utf8, false),
157 Field::new("stages", DataType::Utf8, false),
158 ]));
159 Self { schema }
160 }
161
162 pub fn schema(&self) -> &SchemaRef {
163 &self.schema
164 }
165
166 pub fn convert(&self, jobs: &HashMap<StageId, StageInfo>) -> Result<RecordBatch, ArrowError> {
167 #[derive(Serialize)]
168 struct StagePayload {
169 assigned_partitions: HashSet<usize>,
170 task_set_infos: Vec<TaskSetInfo>,
171 }
172
173 let mut grouped_jobs = BTreeMap::new();
174 for (stage_id, stage_info) in jobs {
175 let (_, _, stages) = grouped_jobs
176 .entry(stage_id.job_id.clone())
177 .or_insert_with(|| {
178 (
179 stage_info.created_at_ms,
180 stage_info.job_meta.clone(),
181 BTreeMap::<String, StagePayload>::new(),
182 )
183 });
184 stages.insert(
185 stage_id.stage.to_string(),
186 StagePayload {
187 assigned_partitions: stage_info.assigned_partitions.clone(),
188 task_set_infos: stage_info.task_set_infos.clone(),
189 },
190 );
191 }
192
193 let mut job_id_builder = StringBuilder::new();
194 let mut created_at_builder = TimestampMillisecondBuilder::new();
195 let mut job_meta_builder = StringBuilder::new();
196 let mut stages_builder = StringBuilder::new();
197
198 for (job_id, (created_at_ms, job_meta, stages)) in grouped_jobs {
199 job_id_builder.append_value(job_id.as_ref());
200 created_at_builder.append_value(created_at_ms);
201 let job_meta_json = serde_json::to_string_pretty(job_meta.as_ref())
202 .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
203 job_meta_builder.append_value(job_meta_json);
204 let stages_json = serde_json::to_string_pretty(&stages)
205 .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
206 stages_builder.append_value(stages_json);
207 }
208
209 RecordBatch::try_new(
210 self.schema.clone(),
211 vec![
212 Arc::new(job_id_builder.finish()),
213 Arc::new(created_at_builder.finish()),
214 Arc::new(job_meta_builder.finish()),
215 Arc::new(stages_builder.finish()),
216 ],
217 )
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::prelude::SessionContext;

    /// Plans `sql` in a fresh session and returns its physical plan.
    async fn physical_plan(sql: &str) -> Arc<dyn ExecutionPlan> {
        let ctx = SessionContext::new();
        let df = ctx.sql(sql).await.unwrap();
        df.create_physical_plan().await.unwrap()
    }

    #[tokio::test]
    async fn test_is_plan_select_1() {
        let plan = physical_plan("SELECT 1").await;
        assert!(is_plan_select_1(&plan));
    }

    #[tokio::test]
    async fn test_is_plan_select_1_rejects_other_queries() {
        // Without negative cases, a matcher that always returned `true` would pass.
        for sql in ["SELECT 2", "SELECT 1, 2"] {
            let plan = physical_plan(sql).await;
            assert!(!is_plan_select_1(&plan), "unexpectedly matched: {sql}");
        }
    }
}