datafusion_dft/tui/
execution.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`TuiExecution`]: Handles executing queries for the TUI application.
19
20use crate::execution::AppExecution;
21use crate::tui::AppEvent;
22use color_eyre::eyre::Result;
23use datafusion::arrow::array::RecordBatch;
24#[allow(unused_imports)] // No idea why this is being picked up as unused when I use it twice.
25use datafusion::arrow::error::ArrowError;
26use datafusion::execution::context::SessionContext;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::execute_stream;
29use futures::StreamExt;
30use log::{error, info};
31use std::sync::Arc;
32use std::time::Duration;
33use tokio::sync::mpsc::UnboundedSender;
34use tokio::sync::Mutex;
35#[cfg(feature = "flightsql")]
36use tokio_stream::StreamMap;
37
38#[cfg(feature = "flightsql")]
39use {
40    arrow_flight::decode::FlightRecordBatchStream,
41    arrow_flight::sql::client::FlightSqlServiceClient, arrow_flight::Ticket,
42    tonic::transport::Channel, tonic::IntoRequest,
43};
44
/// A failed query execution: the SQL text that was run, the rendered error
/// message, and the elapsed time recorded when the failure was reported.
#[derive(Debug, Clone)]
pub struct ExecutionError {
    query: String,
    error: String,
    duration: Duration,
}

impl ExecutionError {
    /// Assembles an [`ExecutionError`] from the query text, error message and duration.
    pub fn new(query: String, error: String, duration: Duration) -> Self {
        ExecutionError {
            duration,
            error,
            query,
        }
    }

    /// The SQL text whose execution failed.
    pub fn query(&self) -> &str {
        self.query.as_str()
    }

    /// The error message describing the failure.
    pub fn error(&self) -> &str {
        self.error.as_str()
    }

    /// The elapsed time recorded alongside the error.
    pub fn duration(&self) -> &Duration {
        let elapsed = &self.duration;
        elapsed
    }
}
73
74#[derive(Clone, Debug)]
75pub struct ExecutionResultsBatch {
76    pub query: String,
77    pub batch: RecordBatch,
78    pub duration: Duration,
79}
80
81impl ExecutionResultsBatch {
82    pub fn new(query: String, batch: RecordBatch, duration: Duration) -> Self {
83        Self {
84            query,
85            batch,
86            duration,
87        }
88    }
89
90    pub fn query(&self) -> &str {
91        &self.query
92    }
93
94    pub fn batch(&self) -> &RecordBatch {
95        &self.batch
96    }
97
98    pub fn duration(&self) -> &Duration {
99        &self.duration
100    }
101}
102
103/// Handles executing queries for the TUI application, formatting results
104/// and sending them to the UI.
105///
106/// TODO: I think we want to store the SQL associated with a stream
107pub struct TuiExecution {
108    inner: Arc<AppExecution>,
109    result_stream: Arc<Mutex<Option<SendableRecordBatchStream>>>,
110    /// StreamMao of FlightSQL streams that could be coming from multiple endpoints / tickets.
111    /// Often times there is only one but we need to be able to handle multiple.  We should test
112    /// this at some point as well.
113    #[cfg(feature = "flightsql")]
114    flightsql_result_stream: Arc<Mutex<Option<StreamMap<String, FlightRecordBatchStream>>>>,
115}
116
117impl TuiExecution {
118    /// Create a new instance of [`AppExecution`].
119    pub fn new(inner: Arc<AppExecution>) -> Self {
120        Self {
121            inner,
122            result_stream: Arc::new(Mutex::new(None)),
123            #[cfg(feature = "flightsql")]
124            flightsql_result_stream: Arc::new(Mutex::new(None)),
125        }
126    }
127
128    pub fn session_ctx(&self) -> &SessionContext {
129        self.inner.session_ctx()
130    }
131
132    pub async fn set_result_stream(&self, stream: SendableRecordBatchStream) {
133        let mut s = self.result_stream.lock().await;
134        *s = Some(stream)
135    }
136
137    #[cfg(feature = "flightsql")]
138    pub async fn set_flightsql_result_stream(
139        &self,
140        ticket: Ticket,
141        stream: FlightRecordBatchStream,
142    ) {
143        let mut s = self.flightsql_result_stream.lock().await;
144        if let Some(ref mut streams) = *s {
145            streams.insert(ticket.to_string(), stream);
146        } else {
147            let mut map: StreamMap<String, FlightRecordBatchStream> = StreamMap::new();
148            let t = ticket.to_string();
149            info!("Adding {t} to FlightSQL streams");
150            map.insert(ticket.to_string(), stream);
151            *s = Some(map);
152        }
153    }
154
155    #[cfg(feature = "flightsql")]
156    pub async fn reset_flightsql_result_stream(&self) {
157        let mut s = self.flightsql_result_stream.lock().await;
158        *s = None;
159    }
160
161    /// Run the sequence of SQL queries, sending the results as
162    /// [`AppEvent::ExecutionResultsBatch`].
163    /// All queries except the last one will have their results discarded.
164    ///
165    /// Error handling: If an error occurs while executing a query, the error is
166    /// logged and execution continues
167    pub async fn run_sqls(
168        self: Arc<Self>,
169        sqls: Vec<String>,
170        sender: UnboundedSender<AppEvent>,
171    ) -> Result<()> {
172        // We need to filter out empty strings to correctly determine the last query for displaying
173        // results.
174        info!("Running sqls: {:?}", sqls);
175        let non_empty_sqls: Vec<String> = sqls.into_iter().filter(|s| !s.is_empty()).collect();
176        info!("Non empty SQLs: {:?}", non_empty_sqls);
177        let statement_count = non_empty_sqls.len();
178        for (i, sql) in non_empty_sqls.into_iter().enumerate() {
179            info!("Running query {}", i);
180            let _sender = sender.clone();
181            let start = std::time::Instant::now();
182            if i == statement_count - 1 {
183                info!("Executing last query and display results");
184                sender.send(AppEvent::NewExecution)?;
185                match self.inner.execution_ctx().create_physical_plan(&sql).await {
186                    Ok(plan) => match execute_stream(plan, self.inner.session_ctx().task_ctx()) {
187                        Ok(stream) => {
188                            self.set_result_stream(stream).await;
189                            let mut stream = self.result_stream.lock().await;
190                            if let Some(s) = stream.as_mut() {
191                                if let Some(b) = s.next().await {
192                                    match b {
193                                        Ok(b) => {
194                                            let duration = start.elapsed();
195                                            let results = ExecutionResultsBatch {
196                                                query: sql.to_string(),
197                                                batch: b,
198                                                duration,
199                                            };
200                                            sender.send(AppEvent::ExecutionResultsNextBatch(
201                                                results,
202                                            ))?;
203                                        }
204                                        Err(e) => {
205                                            error!("Error getting RecordBatch: {:?}", e);
206                                        }
207                                    }
208                                }
209                            }
210                        }
211                        Err(stream_err) => {
212                            error!("Error executing stream: {:?}", stream_err);
213                            let elapsed = start.elapsed();
214                            let e = ExecutionError {
215                                query: sql.to_string(),
216                                error: stream_err.to_string(),
217                                duration: elapsed,
218                            };
219                            sender.send(AppEvent::ExecutionResultsError(e))?;
220                        }
221                    },
222                    Err(plan_err) => {
223                        error!("Error creating physical plan: {:?}", plan_err);
224                        let elapsed = start.elapsed();
225                        let e = ExecutionError {
226                            query: sql.to_string(),
227                            error: plan_err.to_string(),
228                            duration: elapsed,
229                        };
230                        sender.send(AppEvent::ExecutionResultsError(e))?;
231                    }
232                }
233            } else {
234                match self
235                    .inner
236                    .execution_ctx()
237                    .execute_sql_and_discard_results(&sql)
238                    .await
239                {
240                    Ok(_) => {}
241                    Err(e) => {
242                        // We only log failed queries, we don't want to stop the execution of the
243                        // remaining queries. Perhaps there should be a configuration option for
244                        // this though in case the user wants to stop execution on the first error.
245                        error!("Error executing {sql}: {:?}", e);
246                    }
247                }
248            }
249        }
250        Ok(())
251    }
252
253    #[cfg(feature = "flightsql")]
254    pub async fn run_flightsqls(
255        self: Arc<Self>,
256        sqls: Vec<String>,
257        sender: UnboundedSender<AppEvent>,
258    ) -> Result<()> {
259        info!("Running sqls: {:?}", sqls);
260        self.reset_flightsql_result_stream().await;
261        let non_empty_sqls: Vec<String> = sqls.into_iter().filter(|s| !s.is_empty()).collect();
262        let statement_count = non_empty_sqls.len();
263        for (i, sql) in non_empty_sqls.into_iter().enumerate() {
264            let _sender = sender.clone();
265            if i == statement_count - 1 {
266                info!("Executing last query and display results");
267                sender.send(AppEvent::FlightSQLNewExecution)?;
268                if let Some(ref mut client) = *self.flightsql_client().lock().await {
269                    let start = std::time::Instant::now();
270                    match client.execute(sql.clone(), None).await {
271                        Ok(flight_info) => {
272                            for endpoint in flight_info.endpoint {
273                                if let Some(ticket) = endpoint.ticket {
274                                    match client.do_get(ticket.clone().into_request()).await {
275                                        Ok(stream) => {
276                                            self.set_flightsql_result_stream(ticket, stream).await;
277                                            if let Some(streams) =
278                                                self.flightsql_result_stream.lock().await.as_mut()
279                                            {
280                                                match streams.next().await {
281                                                    Some((ticket, Ok(batch))) => {
282                                                        info!("Received batch for {ticket}");
283                                                        let duration = start.elapsed();
284                                                        let results = ExecutionResultsBatch {
285                                                            batch,
286                                                            duration,
287                                                            query: sql.to_string(),
288                                                        };
289                                                        sender.send(
290                                                            AppEvent::FlightSQLExecutionResultsNextBatch(
291                                                                results,
292                                                            ),
293                                                        )?;
294                                                    }
295                                                    Some((ticket, Err(e))) => {
296                                                        error!(
297                                                            "Error executing stream for ticket {ticket}: {:?}",
298                                                            e
299                                                        );
300                                                        let elapsed = start.elapsed();
301                                                        let e = ExecutionError {
302                                                            query: sql.to_string(),
303                                                            error: e.to_string(),
304                                                            duration: elapsed,
305                                                        };
306                                                        sender.send(
307                                                            AppEvent::FlightSQLExecutionResultsError(e),
308                                                        )?;
309                                                    }
310                                                    None => {}
311                                                }
312                                            }
313                                        }
314                                        Err(e) => {
315                                            error!("Error creating result stream: {:?}", e);
316                                            if let ArrowError::IpcError(ipc_err) = &e {
317                                                if ipc_err.contains("error trying to connect") {
318                                                    let e = ExecutionError {
319                                                        query: sql.to_string(),
320                                                        error: "Error connecting to Flight server"
321                                                            .to_string(),
322                                                        duration: std::time::Duration::from_secs(0),
323                                                    };
324                                                    sender.send(
325                                                        AppEvent::FlightSQLExecutionResultsError(e),
326                                                    )?;
327                                                    return Ok(());
328                                                }
329                                            }
330
331                                            let elapsed = start.elapsed();
332                                            let e = ExecutionError {
333                                                query: sql.to_string(),
334                                                error: e.to_string(),
335                                                duration: elapsed,
336                                            };
337                                            sender.send(
338                                                AppEvent::FlightSQLExecutionResultsError(e),
339                                            )?;
340                                        }
341                                    }
342                                }
343                            }
344                        }
345                        Err(e) => {
346                            error!("Error getting flight info: {:?}", e);
347                            if let ArrowError::IpcError(ipc_err) = &e {
348                                if ipc_err.contains("error trying to connect") {
349                                    let e = ExecutionError {
350                                        query: sql.to_string(),
351                                        error: "Error connecting to Flight server".to_string(),
352                                        duration: std::time::Duration::from_secs(0),
353                                    };
354                                    sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
355                                    return Ok(());
356                                }
357                            }
358                            let elapsed = start.elapsed();
359                            let e = ExecutionError {
360                                query: sql.to_string(),
361                                error: e.to_string(),
362                                duration: elapsed,
363                            };
364                            sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
365                        }
366                    }
367                } else {
368                    let e = ExecutionError {
369                        query: sql.to_string(),
370                        error: "No FlightSQL client".to_string(),
371                        duration: std::time::Duration::from_secs(0),
372                    };
373                    sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
374                }
375            }
376        }
377
378        Ok(())
379    }
380
381    pub async fn next_batch(&self, sql: String, sender: UnboundedSender<AppEvent>) {
382        let mut stream = self.result_stream.lock().await;
383        if let Some(s) = stream.as_mut() {
384            let start = std::time::Instant::now();
385            if let Some(b) = s.next().await {
386                match b {
387                    Ok(b) => {
388                        let duration = start.elapsed();
389                        let results = ExecutionResultsBatch {
390                            query: sql,
391                            batch: b,
392                            duration,
393                        };
394                        let _ = sender.send(AppEvent::ExecutionResultsNextBatch(results));
395                    }
396                    Err(e) => {
397                        error!("Error getting RecordBatch: {:?}", e);
398                    }
399                }
400            }
401        }
402    }
403
404    // TODO: Maybe just expose `inner` and use that rather than re-implementing the same
405    // functions here.
406    #[cfg(feature = "flightsql")]
407    pub async fn create_flightsql_client(&self, cli_host: Option<String>) -> Result<()> {
408        self.inner.flightsql_ctx().create_client(cli_host).await
409    }
410
411    #[cfg(feature = "flightsql")]
412    pub fn flightsql_client(&self) -> &Mutex<Option<FlightSqlServiceClient<Channel>>> {
413        self.inner.flightsql_client()
414    }
415
416    pub fn load_ddl(&self) -> Option<String> {
417        self.inner.execution_ctx().load_ddl()
418    }
419
420    pub fn save_ddl(&self, ddl: String) {
421        self.inner.execution_ctx().save_ddl(ddl)
422    }
423}