1use crate::execution::AppExecution;
21use crate::tui::AppEvent;
22use color_eyre::eyre::Result;
23use datafusion::arrow::array::RecordBatch;
24#[allow(unused_imports)] use datafusion::arrow::error::ArrowError;
26use datafusion::execution::context::SessionContext;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::execute_stream;
29use futures::StreamExt;
30use log::{error, info};
31use std::sync::Arc;
32use std::time::Duration;
33use tokio::sync::mpsc::UnboundedSender;
34use tokio::sync::Mutex;
35#[cfg(feature = "flightsql")]
36use tokio_stream::StreamMap;
37
38#[cfg(feature = "flightsql")]
39use {
40 arrow_flight::decode::FlightRecordBatchStream,
41 arrow_flight::sql::client::FlightSqlServiceClient, arrow_flight::Ticket,
42 tonic::transport::Channel, tonic::IntoRequest,
43};
44
/// Failure report for a single query: the SQL text, the error message, and the elapsed
/// time recorded by whoever built the report (callers in this module use zero when the
/// time is not meaningful, e.g. for connection failures).
#[derive(Clone, Debug)]
pub struct ExecutionError {
    /// SQL text of the statement that failed.
    query: String,
    /// Error message intended for display.
    error: String,
    /// Elapsed time attached to the failure.
    duration: Duration,
}

impl ExecutionError {
    /// Builds a report from the failing query, its error message and the elapsed time.
    pub fn new(query: String, error: String, duration: Duration) -> Self {
        Self {
            error,
            duration,
            query,
        }
    }

    /// SQL text of the failed statement.
    pub fn query(&self) -> &str {
        self.query.as_str()
    }

    /// Error message describing the failure.
    pub fn error(&self) -> &str {
        self.error.as_str()
    }

    /// Elapsed time attached to the failure.
    pub fn duration(&self) -> &Duration {
        &self.duration
    }
}
73
74#[derive(Clone, Debug)]
75pub struct ExecutionResultsBatch {
76 pub query: String,
77 pub batch: RecordBatch,
78 pub duration: Duration,
79}
80
81impl ExecutionResultsBatch {
82 pub fn new(query: String, batch: RecordBatch, duration: Duration) -> Self {
83 Self {
84 query,
85 batch,
86 duration,
87 }
88 }
89
90 pub fn query(&self) -> &str {
91 &self.query
92 }
93
94 pub fn batch(&self) -> &RecordBatch {
95 &self.batch
96 }
97
98 pub fn duration(&self) -> &Duration {
99 &self.duration
100 }
101}
102
103pub struct TuiExecution {
108 inner: Arc<AppExecution>,
109 result_stream: Arc<Mutex<Option<SendableRecordBatchStream>>>,
110 #[cfg(feature = "flightsql")]
114 flightsql_result_stream: Arc<Mutex<Option<StreamMap<String, FlightRecordBatchStream>>>>,
115}
116
117impl TuiExecution {
118 pub fn new(inner: Arc<AppExecution>) -> Self {
120 Self {
121 inner,
122 result_stream: Arc::new(Mutex::new(None)),
123 #[cfg(feature = "flightsql")]
124 flightsql_result_stream: Arc::new(Mutex::new(None)),
125 }
126 }
127
128 pub fn session_ctx(&self) -> &SessionContext {
129 self.inner.session_ctx()
130 }
131
132 pub async fn set_result_stream(&self, stream: SendableRecordBatchStream) {
133 let mut s = self.result_stream.lock().await;
134 *s = Some(stream)
135 }
136
137 #[cfg(feature = "flightsql")]
138 pub async fn set_flightsql_result_stream(
139 &self,
140 ticket: Ticket,
141 stream: FlightRecordBatchStream,
142 ) {
143 let mut s = self.flightsql_result_stream.lock().await;
144 if let Some(ref mut streams) = *s {
145 streams.insert(ticket.to_string(), stream);
146 } else {
147 let mut map: StreamMap<String, FlightRecordBatchStream> = StreamMap::new();
148 let t = ticket.to_string();
149 info!("Adding {t} to FlightSQL streams");
150 map.insert(ticket.to_string(), stream);
151 *s = Some(map);
152 }
153 }
154
155 #[cfg(feature = "flightsql")]
156 pub async fn reset_flightsql_result_stream(&self) {
157 let mut s = self.flightsql_result_stream.lock().await;
158 *s = None;
159 }
160
161 pub async fn run_sqls(
168 self: Arc<Self>,
169 sqls: Vec<String>,
170 sender: UnboundedSender<AppEvent>,
171 ) -> Result<()> {
172 info!("Running sqls: {:?}", sqls);
175 let non_empty_sqls: Vec<String> = sqls.into_iter().filter(|s| !s.is_empty()).collect();
176 info!("Non empty SQLs: {:?}", non_empty_sqls);
177 let statement_count = non_empty_sqls.len();
178 for (i, sql) in non_empty_sqls.into_iter().enumerate() {
179 info!("Running query {}", i);
180 let _sender = sender.clone();
181 let start = std::time::Instant::now();
182 if i == statement_count - 1 {
183 info!("Executing last query and display results");
184 sender.send(AppEvent::NewExecution)?;
185 match self.inner.execution_ctx().create_physical_plan(&sql).await {
186 Ok(plan) => match execute_stream(plan, self.inner.session_ctx().task_ctx()) {
187 Ok(stream) => {
188 self.set_result_stream(stream).await;
189 let mut stream = self.result_stream.lock().await;
190 if let Some(s) = stream.as_mut() {
191 if let Some(b) = s.next().await {
192 match b {
193 Ok(b) => {
194 let duration = start.elapsed();
195 let results = ExecutionResultsBatch {
196 query: sql.to_string(),
197 batch: b,
198 duration,
199 };
200 sender.send(AppEvent::ExecutionResultsNextBatch(
201 results,
202 ))?;
203 }
204 Err(e) => {
205 error!("Error getting RecordBatch: {:?}", e);
206 }
207 }
208 }
209 }
210 }
211 Err(stream_err) => {
212 error!("Error executing stream: {:?}", stream_err);
213 let elapsed = start.elapsed();
214 let e = ExecutionError {
215 query: sql.to_string(),
216 error: stream_err.to_string(),
217 duration: elapsed,
218 };
219 sender.send(AppEvent::ExecutionResultsError(e))?;
220 }
221 },
222 Err(plan_err) => {
223 error!("Error creating physical plan: {:?}", plan_err);
224 let elapsed = start.elapsed();
225 let e = ExecutionError {
226 query: sql.to_string(),
227 error: plan_err.to_string(),
228 duration: elapsed,
229 };
230 sender.send(AppEvent::ExecutionResultsError(e))?;
231 }
232 }
233 } else {
234 match self
235 .inner
236 .execution_ctx()
237 .execute_sql_and_discard_results(&sql)
238 .await
239 {
240 Ok(_) => {}
241 Err(e) => {
242 error!("Error executing {sql}: {:?}", e);
246 }
247 }
248 }
249 }
250 Ok(())
251 }
252
253 #[cfg(feature = "flightsql")]
254 pub async fn run_flightsqls(
255 self: Arc<Self>,
256 sqls: Vec<String>,
257 sender: UnboundedSender<AppEvent>,
258 ) -> Result<()> {
259 info!("Running sqls: {:?}", sqls);
260 self.reset_flightsql_result_stream().await;
261 let non_empty_sqls: Vec<String> = sqls.into_iter().filter(|s| !s.is_empty()).collect();
262 let statement_count = non_empty_sqls.len();
263 for (i, sql) in non_empty_sqls.into_iter().enumerate() {
264 let _sender = sender.clone();
265 if i == statement_count - 1 {
266 info!("Executing last query and display results");
267 sender.send(AppEvent::FlightSQLNewExecution)?;
268 if let Some(ref mut client) = *self.flightsql_client().lock().await {
269 let start = std::time::Instant::now();
270 match client.execute(sql.clone(), None).await {
271 Ok(flight_info) => {
272 for endpoint in flight_info.endpoint {
273 if let Some(ticket) = endpoint.ticket {
274 match client.do_get(ticket.clone().into_request()).await {
275 Ok(stream) => {
276 self.set_flightsql_result_stream(ticket, stream).await;
277 if let Some(streams) =
278 self.flightsql_result_stream.lock().await.as_mut()
279 {
280 match streams.next().await {
281 Some((ticket, Ok(batch))) => {
282 info!("Received batch for {ticket}");
283 let duration = start.elapsed();
284 let results = ExecutionResultsBatch {
285 batch,
286 duration,
287 query: sql.to_string(),
288 };
289 sender.send(
290 AppEvent::FlightSQLExecutionResultsNextBatch(
291 results,
292 ),
293 )?;
294 }
295 Some((ticket, Err(e))) => {
296 error!(
297 "Error executing stream for ticket {ticket}: {:?}",
298 e
299 );
300 let elapsed = start.elapsed();
301 let e = ExecutionError {
302 query: sql.to_string(),
303 error: e.to_string(),
304 duration: elapsed,
305 };
306 sender.send(
307 AppEvent::FlightSQLExecutionResultsError(e),
308 )?;
309 }
310 None => {}
311 }
312 }
313 }
314 Err(e) => {
315 error!("Error creating result stream: {:?}", e);
316 if let ArrowError::IpcError(ipc_err) = &e {
317 if ipc_err.contains("error trying to connect") {
318 let e = ExecutionError {
319 query: sql.to_string(),
320 error: "Error connecting to Flight server"
321 .to_string(),
322 duration: std::time::Duration::from_secs(0),
323 };
324 sender.send(
325 AppEvent::FlightSQLExecutionResultsError(e),
326 )?;
327 return Ok(());
328 }
329 }
330
331 let elapsed = start.elapsed();
332 let e = ExecutionError {
333 query: sql.to_string(),
334 error: e.to_string(),
335 duration: elapsed,
336 };
337 sender.send(
338 AppEvent::FlightSQLExecutionResultsError(e),
339 )?;
340 }
341 }
342 }
343 }
344 }
345 Err(e) => {
346 error!("Error getting flight info: {:?}", e);
347 if let ArrowError::IpcError(ipc_err) = &e {
348 if ipc_err.contains("error trying to connect") {
349 let e = ExecutionError {
350 query: sql.to_string(),
351 error: "Error connecting to Flight server".to_string(),
352 duration: std::time::Duration::from_secs(0),
353 };
354 sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
355 return Ok(());
356 }
357 }
358 let elapsed = start.elapsed();
359 let e = ExecutionError {
360 query: sql.to_string(),
361 error: e.to_string(),
362 duration: elapsed,
363 };
364 sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
365 }
366 }
367 } else {
368 let e = ExecutionError {
369 query: sql.to_string(),
370 error: "No FlightSQL client".to_string(),
371 duration: std::time::Duration::from_secs(0),
372 };
373 sender.send(AppEvent::FlightSQLExecutionResultsError(e))?;
374 }
375 }
376 }
377
378 Ok(())
379 }
380
381 pub async fn next_batch(&self, sql: String, sender: UnboundedSender<AppEvent>) {
382 let mut stream = self.result_stream.lock().await;
383 if let Some(s) = stream.as_mut() {
384 let start = std::time::Instant::now();
385 if let Some(b) = s.next().await {
386 match b {
387 Ok(b) => {
388 let duration = start.elapsed();
389 let results = ExecutionResultsBatch {
390 query: sql,
391 batch: b,
392 duration,
393 };
394 let _ = sender.send(AppEvent::ExecutionResultsNextBatch(results));
395 }
396 Err(e) => {
397 error!("Error getting RecordBatch: {:?}", e);
398 }
399 }
400 }
401 }
402 }
403
404 #[cfg(feature = "flightsql")]
407 pub async fn create_flightsql_client(&self, cli_host: Option<String>) -> Result<()> {
408 self.inner.flightsql_ctx().create_client(cli_host).await
409 }
410
411 #[cfg(feature = "flightsql")]
412 pub fn flightsql_client(&self) -> &Mutex<Option<FlightSqlServiceClient<Channel>>> {
413 self.inner.flightsql_client()
414 }
415
416 pub fn load_ddl(&self) -> Option<String> {
417 self.inner.execution_ctx().load_ddl()
418 }
419
420 pub fn save_ddl(&self, ddl: String) {
421 self.inner.execution_ctx().save_ddl(ddl)
422 }
423}