Skip to main content

datafusion_dist/
util.rs

1use std::{
2    collections::{BTreeMap, HashMap, HashSet},
3    sync::Arc,
4    time::{SystemTime, UNIX_EPOCH},
5};
6
7use arrow::{
8    array::{RecordBatch, StringBuilder, TimestampMillisecondBuilder},
9    datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
10    error::ArrowError,
11};
12use datafusion_common::ScalarValue;
13use datafusion_physical_expr::expressions::Literal;
14use datafusion_physical_plan::{
15    ExecutionPlan, placeholder_row::PlaceholderRowExec, projection::ProjectionExec,
16};
17use futures::{StreamExt, stream::BoxStream};
18use serde::Serialize;
19use tokio::{
20    runtime::Handle,
21    sync::mpsc::{Receiver, Sender},
22    task::{AbortHandle, JoinHandle},
23};
24
25use crate::{
26    DistError, DistResult,
27    network::{StageInfo, TaskSetInfo},
28    planner::StageId,
29};
30
31/// Check if the physical plan is a simple `SELECT 1` query.
32/// This is used to identify queries that should be executed locally.
33pub fn is_plan_select_1(plan: &Arc<dyn ExecutionPlan>) -> bool {
34    let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() else {
35        return false;
36    };
37    if !proj.input().as_any().is::<PlaceholderRowExec>() {
38        return false;
39    }
40    if proj.expr().len() != 1 {
41        return false;
42    }
43    let expr = &proj.expr()[0];
44    let Some(literal) = expr.expr.as_any().downcast_ref::<Literal>() else {
45        return false;
46    };
47    matches!(
48        literal.value(),
49        ScalarValue::Int32(Some(1)) | ScalarValue::Int64(Some(1))
50    )
51}
52
/// Returns the current wall-clock time as whole milliseconds since the UNIX epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the UNIX epoch.
pub fn timestamp_ms() -> i64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_millis();
    // `as_millis` yields `u128`; saturate rather than silently wrapping with an `as` cast
    // if the value ever exceeds `i64::MAX`.
    i64::try_from(millis).unwrap_or(i64::MAX)
}
59
60// This function will spawn thread to get the local IP address, so don't call it frequently
61pub fn get_local_ip() -> String {
62    local_ip_address::local_ip()
63        .expect("Failed to get local IP")
64        .to_string()
65}
66
67pub struct ReceiverStreamBuilder<O> {
68    tx: Sender<DistResult<O>>,
69    rx: Receiver<DistResult<O>>,
70    task: Option<JoinHandle<DistResult<()>>>,
71}
72
73impl<O: Send + 'static> ReceiverStreamBuilder<O> {
74    /// Create new channels with the specified buffer size
75    pub fn new(capacity: usize) -> Self {
76        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
77
78        Self { tx, rx, task: None }
79    }
80
81    /// Get a handle for sending data to the output
82    pub fn tx(&self) -> Sender<DistResult<O>> {
83        self.tx.clone()
84    }
85
86    /// Spawn the task on the provided runtime and return a handle for cancellation.
87    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
88    where
89        F: Future<Output = DistResult<()>>,
90        F: Send + 'static,
91    {
92        assert!(
93            self.task.is_none(),
94            "ReceiverStreamBuilder supports a single task"
95        );
96        let join_handle = handle.spawn(task);
97        let abort_handle = join_handle.abort_handle();
98        self.task = Some(join_handle);
99        abort_handle
100    }
101
102    /// Create a stream of all data written to `tx`
103    pub fn build(self) -> BoxStream<'static, DistResult<O>> {
104        let Self { tx, rx, task } = self;
105
106        // Doesn't need tx
107        drop(tx);
108
109        // Future that checks the spawned task result, and propagates panic if seen.
110        let check = async move {
111            let task = task?;
112
113            match task.await {
114                Ok(Ok(())) => None,
115                Ok(Err(error)) => Some(Err(error)),
116                Err(e) => Some(Err(DistError::internal(format!("Tokio join error: {e}")))),
117            }
118        };
119
120        let check_stream = futures::stream::once(check)
121            // unwrap Option / only return the error
122            .filter_map(|item| async move { item });
123
124        // Convert the receiver into a stream
125        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
126            let next_item = rx.recv().await;
127            next_item.map(|next_item| (next_item, rx))
128        });
129
130        // Merge the streams together so whichever is ready first
131        // produces the batch
132        futures::stream::select(rx_stream, check_stream).boxed()
133    }
134}
135
136#[derive(Debug)]
137pub struct JobsArrowConverter {
138    schema: SchemaRef,
139}
140
141impl Default for JobsArrowConverter {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl JobsArrowConverter {
148    pub fn new() -> Self {
149        let schema = Arc::new(Schema::new(vec![
150            Field::new("job_id", DataType::Utf8, false),
151            Field::new(
152                "created_at",
153                DataType::Timestamp(TimeUnit::Millisecond, None),
154                false,
155            ),
156            Field::new("job_meta", DataType::Utf8, false),
157            Field::new("stages", DataType::Utf8, false),
158        ]));
159        Self { schema }
160    }
161
162    pub fn schema(&self) -> &SchemaRef {
163        &self.schema
164    }
165
166    pub fn convert(&self, jobs: &HashMap<StageId, StageInfo>) -> Result<RecordBatch, ArrowError> {
167        #[derive(Serialize)]
168        struct StagePayload {
169            assigned_partitions: HashSet<usize>,
170            task_set_infos: Vec<TaskSetInfo>,
171        }
172
173        let mut grouped_jobs = BTreeMap::new();
174        for (stage_id, stage_info) in jobs {
175            let (_, _, stages) = grouped_jobs
176                .entry(stage_id.job_id.clone())
177                .or_insert_with(|| {
178                    (
179                        stage_info.created_at_ms,
180                        stage_info.job_meta.clone(),
181                        BTreeMap::<String, StagePayload>::new(),
182                    )
183                });
184            stages.insert(
185                stage_id.stage.to_string(),
186                StagePayload {
187                    assigned_partitions: stage_info.assigned_partitions.clone(),
188                    task_set_infos: stage_info.task_set_infos.clone(),
189                },
190            );
191        }
192
193        let mut job_id_builder = StringBuilder::new();
194        let mut created_at_builder = TimestampMillisecondBuilder::new();
195        let mut job_meta_builder = StringBuilder::new();
196        let mut stages_builder = StringBuilder::new();
197
198        for (job_id, (created_at_ms, job_meta, stages)) in grouped_jobs {
199            job_id_builder.append_value(job_id.as_ref());
200            created_at_builder.append_value(created_at_ms);
201            let job_meta_json = serde_json::to_string_pretty(job_meta.as_ref())
202                .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
203            job_meta_builder.append_value(job_meta_json);
204            let stages_json = serde_json::to_string_pretty(&stages)
205                .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
206            stages_builder.append_value(stages_json);
207        }
208
209        RecordBatch::try_new(
210            self.schema.clone(),
211            vec![
212                Arc::new(job_id_builder.finish()),
213                Arc::new(created_at_builder.finish()),
214                Arc::new(job_meta_builder.finish()),
215                Arc::new(stages_builder.finish()),
216            ],
217        )
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::prelude::SessionContext;

    /// Plans `sql` in a fresh session and reports whether `is_plan_select_1` accepts it.
    async fn plans_as_select_1(sql: &str) -> bool {
        let ctx = SessionContext::new();
        let df = ctx.sql(sql).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        is_plan_select_1(&plan)
    }

    #[tokio::test]
    async fn test_is_plan_select_1() {
        assert!(plans_as_select_1("SELECT 1").await);
    }

    #[tokio::test]
    async fn test_is_plan_select_1_rejects_other_queries() {
        // Same shape, but the literal is not 1.
        assert!(!plans_as_select_1("SELECT 2").await);
        // More than one projected expression.
        assert!(!plans_as_select_1("SELECT 1, 2").await);
    }
}