kapot_scheduler/
flight_sql.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_flight::flight_descriptor::DescriptorType;
19use arrow_flight::flight_service_server::FlightService;
20use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
21use arrow_flight::sql::{
22    ActionBeginSavepointRequest, ActionBeginSavepointResult,
23    ActionBeginTransactionRequest, ActionBeginTransactionResult,
24    ActionCancelQueryRequest, ActionCancelQueryResult,
25    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
26    ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
27    ActionEndSavepointRequest, ActionEndTransactionRequest, CommandGetCatalogs,
28    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
29    CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
30    CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
31    CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
32    CommandStatementSubstraitPlan, CommandStatementUpdate, DoPutPreparedStatementResult,
33    SqlInfo, TicketStatementQuery,
34};
35use arrow_flight::{
36    Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
37    HandshakeResponse, Ticket,
38};
39use base64::Engine;
40use futures::Stream;
41use log::{debug, error, warn};
42use std::convert::TryFrom;
43use std::pin::Pin;
44use std::str::FromStr;
45use std::string::ToString;
46use std::sync::Arc;
47use std::time::Duration;
48use tonic::{Request, Response, Status, Streaming};
49
50use crate::scheduler_server::SchedulerServer;
51use arrow_flight::flight_service_client::FlightServiceClient;
52use arrow_flight::sql::ProstMessageExt;
53use arrow_flight::utils::batches_to_flight_data;
54use arrow_flight::SchemaAsIpc;
55use kapot_core::config::KapotConfig;
56use kapot_core::serde::protobuf;
57use kapot_core::serde::protobuf::action::ActionType::FetchPartition;
58use kapot_core::serde::protobuf::job_status;
59use kapot_core::serde::protobuf::JobStatus;
60use kapot_core::serde::protobuf::SuccessfulJob;
61use kapot_core::utils::create_grpc_client_connection;
62use dashmap::DashMap;
63use datafusion::arrow;
64use datafusion::arrow::array::{ArrayRef, StringArray};
65use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
66use datafusion::arrow::error::ArrowError;
67use datafusion::arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
68use datafusion::arrow::record_batch::RecordBatch;
69use datafusion::common::DFSchemaRef;
70use datafusion::logical_expr::LogicalPlan;
71use datafusion::prelude::SessionContext;
72use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
73use prost::bytes::Bytes;
74use prost::Message;
75use tokio::sync::mpsc::{channel, Receiver, Sender};
76use tokio::time::sleep;
77use tokio_stream::wrappers::ReceiverStream;
78use tonic::metadata::MetadataValue;
79use uuid::Uuid;
80
81pub struct FlightSqlServiceImpl {
82    server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode>,
83    statements: Arc<DashMap<Uuid, LogicalPlan>>,
84    contexts: Arc<DashMap<Uuid, Arc<SessionContext>>>,
85}
86
87const TABLE_TYPES: [&str; 2] = ["TABLE", "VIEW"];
88
89impl FlightSqlServiceImpl {
90    pub fn new(server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode>) -> Self {
91        Self {
92            server,
93            statements: Default::default(),
94            contexts: Default::default(),
95        }
96    }
97
98    fn tables(&self, ctx: Arc<SessionContext>) -> Result<RecordBatch, ArrowError> {
99        let schema = Arc::new(Schema::new(vec![
100            Field::new("catalog_name", DataType::Utf8, true),
101            Field::new("db_schema_name", DataType::Utf8, true),
102            Field::new("table_name", DataType::Utf8, false),
103            Field::new("table_type", DataType::Utf8, false),
104        ]));
105        let mut names: Vec<Option<String>> = vec![];
106        for catalog_name in ctx.catalog_names() {
107            let catalog = ctx
108                .catalog(&catalog_name)
109                .expect("catalog should have been found");
110            for schema_name in catalog.schema_names() {
111                let schema = catalog
112                    .schema(&schema_name)
113                    .expect("schema should have been found");
114                for table_name in schema.table_names() {
115                    names.push(Some(table_name));
116                }
117            }
118        }
119        let types: Vec<_> = names.iter().map(|_| Some("TABLE".to_string())).collect();
120        let cats: Vec<_> = names.iter().map(|_| None).collect();
121        let schemas: Vec<_> = names.iter().map(|_| None).collect();
122        let rb = RecordBatch::try_new(
123            schema,
124            [cats, schemas, names, types]
125                .iter()
126                .map(|i| Arc::new(StringArray::from(i.clone())) as ArrayRef)
127                .collect::<Vec<_>>(),
128        )?;
129        Ok(rb)
130    }
131
132    fn table_types() -> Result<RecordBatch, ArrowError> {
133        let schema = Arc::new(Schema::new(vec![Field::new(
134            "table_type",
135            DataType::Utf8,
136            false,
137        )]));
138        RecordBatch::try_new(
139            schema,
140            [TABLE_TYPES]
141                .iter()
142                .map(|i| Arc::new(StringArray::from(i.to_vec())) as ArrayRef)
143                .collect::<Vec<_>>(),
144        )
145    }
146
147    async fn create_ctx(&self) -> Result<Uuid, Status> {
148        let config_builder = KapotConfig::builder();
149        let config = config_builder
150            .build()
151            .map_err(|e| Status::internal(format!("Error building config: {e}")))?;
152        let ctx = self
153            .server
154            .state
155            .session_manager
156            .create_session(&config)
157            .await
158            .map_err(|e| {
159                Status::internal(format!("Failed to create SessionContext: {e:?}"))
160            })?;
161        let handle = Uuid::new_v4();
162        self.contexts.insert(handle, ctx);
163        Ok(handle)
164    }
165
166    fn get_ctx<T>(&self, req: &Request<T>) -> Result<Arc<SessionContext>, Status> {
167        let auth = req
168            .metadata()
169            .get("authorization")
170            .ok_or_else(|| Status::internal("No authorization header!"))?;
171        let str = auth
172            .to_str()
173            .map_err(|e| Status::internal(format!("Error parsing header: {e}")))?;
174        let authorization = str.to_string();
175        let bearer = "Bearer ";
176        if !authorization.starts_with(bearer) {
177            Err(Status::internal("Invalid auth header!"))?;
178        }
179        let auth = authorization[bearer.len()..].to_string();
180
181        let handle = Uuid::from_str(auth.as_str())
182            .map_err(|e| Status::internal(format!("Error locking contexts: {e}")))?;
183        if let Some(context) = self.contexts.get(&handle) {
184            Ok(context.clone())
185        } else {
186            Err(Status::internal(format!(
187                "Context handle not found: {handle}"
188            )))?
189        }
190    }
191
192    async fn prepare_statement(
193        query: &str,
194        ctx: &Arc<SessionContext>,
195    ) -> Result<LogicalPlan, Status> {
196        let plan = ctx
197            .sql(query)
198            .await
199            .and_then(|df| df.into_optimized_plan())
200            .map_err(|e| Status::internal(format!("Error building plan: {e}")))?;
201        Ok(plan)
202    }
203
204    async fn check_job(&self, job_id: &String) -> Result<Option<SuccessfulJob>, Status> {
205        let status = self
206            .server
207            .state
208            .task_manager
209            .get_job_status(job_id)
210            .await
211            .map_err(|e| {
212                let msg = format!("Error getting status for job {job_id}: {e:?}");
213                error!("{}", msg);
214                Status::internal(msg)
215            })?;
216        let status: JobStatus = match status {
217            Some(status) => status,
218            None => {
219                let msg = format!("Error getting status for job {job_id}!");
220                error!("{}", msg);
221                Err(Status::internal(msg))?
222            }
223        };
224        let status: job_status::Status = match status.status {
225            Some(status) => status,
226            None => {
227                let msg = format!("Error getting status for job {job_id}!");
228                error!("{}", msg);
229                Err(Status::internal(msg))?
230            }
231        };
232        match status {
233            job_status::Status::Queued(_) => Ok(None),
234            job_status::Status::Running(_) => Ok(None),
235            job_status::Status::Failed(e) => {
236                warn!("Error executing plan: {:?}", e);
237                Err(Status::internal(format!(
238                    "Error executing plan: {}",
239                    e.error
240                )))?
241            }
242            job_status::Status::Successful(comp) => Ok(Some(comp)),
243        }
244    }
245
246    async fn job_to_fetch_part(
247        &self,
248        completed: SuccessfulJob,
249        num_rows: &mut i64,
250        num_bytes: &mut i64,
251    ) -> Result<Vec<FlightEndpoint>, Status> {
252        let mut fieps: Vec<_> = vec![];
253        for loc in completed.partition_location.iter() {
254            let (exec_host, exec_port) = if let Some(ref md) = loc.executor_meta {
255                (md.host.clone(), md.port)
256            } else {
257                Err(Status::internal(
258                    "Invalid partition location, missing executor metadata and advertise_endpoint flag is undefined.".to_string(),
259                ))?
260            };
261
262            let (host, port) = match &self
263                .server
264                .state
265                .config
266                .advertise_flight_sql_endpoint
267            {
268                Some(endpoint) => {
269                    let advertise_endpoint_vec: Vec<&str> = endpoint.split(':').collect();
270                    match advertise_endpoint_vec.as_slice() {
271                        [host_ip, port] => {
272                            (String::from(*host_ip), FromStr::from_str(port).expect("Failed to parse port from advertise-endpoint."))
273                        }
274                        _ => {
275                            Err(Status::internal("advertise-endpoint flag has incorrect format. Expected IP:Port".to_string()))?
276                        }
277                    }
278                }
279                None => (exec_host.clone(), exec_port),
280            };
281
282            let fetch = if let Some(ref id) = loc.partition_id {
283                let fetch = protobuf::FetchPartition {
284                    job_id: id.job_id.clone(),
285                    stage_id: id.stage_id,
286                    partition_id: id.partition_id,
287                    path: loc.path.clone(),
288                    // Use executor ip:port for routing to flight result
289                    host: exec_host.clone(),
290                    port: exec_port,
291                };
292                protobuf::Action {
293                    action_type: Some(FetchPartition(fetch)),
294                    settings: vec![],
295                }
296            } else {
297                Err(Status::internal("Error getting partition ID".to_string()))?
298            };
299            if let Some(ref stats) = loc.partition_stats {
300                *num_rows += stats.num_rows;
301                *num_bytes += stats.num_bytes;
302            } else {
303                Err(Status::internal("Error getting stats".to_string()))?
304            }
305            let authority = format!("{}:{}", &host, &port);
306            let buf = fetch.as_any().encode_to_vec();
307            let ticket = Ticket { ticket: buf.into() };
308            let fiep = FlightEndpoint::new()
309                .with_ticket(ticket)
310                .with_location(format!("grpc+tcp://{authority}"));
311            fieps.push(fiep);
312        }
313        Ok(fieps)
314    }
315
316    fn make_local_fieps(&self, job_id: &str) -> Result<Vec<FlightEndpoint>, Status> {
317        let (host, port) = ("127.0.0.1".to_string(), 50050); // TODO: use advertise host
318        let fetch = protobuf::FetchPartition {
319            job_id: job_id.to_string(),
320            stage_id: 0,
321            partition_id: 0,
322            path: job_id.to_string(),
323            host: host.clone(),
324            port,
325        };
326        let fetch = protobuf::Action {
327            action_type: Some(FetchPartition(fetch)),
328            settings: vec![],
329        };
330        let authority = format!("{}:{}", &host, &port); // TODO: use advertise host
331        let buf = fetch.as_any().encode_to_vec();
332        let ticket = Ticket { ticket: buf.into() };
333        let fiep = FlightEndpoint::new()
334            .with_ticket(ticket)
335            .with_location(format!("grpc+tcp://{authority}"));
336        let fieps = vec![fiep];
337        Ok(fieps)
338    }
339
340    fn cache_plan(&self, plan: LogicalPlan) -> Result<Uuid, Status> {
341        let handle = Uuid::new_v4();
342        self.statements.insert(handle, plan);
343        Ok(handle)
344    }
345
346    fn get_plan(&self, handle: &Uuid) -> Result<LogicalPlan, Status> {
347        if let Some(plan) = self.statements.get(handle) {
348            Ok(plan.clone())
349        } else {
350            Err(Status::internal(format!(
351                "Statement handle not found: {handle}"
352            )))?
353        }
354    }
355
356    fn remove_plan(&self, handle: Uuid) -> Result<(), Status> {
357        self.statements.remove(&handle);
358        Ok(())
359    }
360
361    fn df_schema_to_arrow(&self, schema: &DFSchemaRef) -> Result<Vec<u8>, Status> {
362        let arrow_schema: Schema = (&**schema).into();
363        let schema_bytes = self.schema_to_arrow(Arc::new(arrow_schema))?;
364        Ok(schema_bytes)
365    }
366
367    fn schema_to_arrow(&self, arrow_schema: SchemaRef) -> Result<Vec<u8>, Status> {
368        let options = IpcWriteOptions::default();
369        let pair = SchemaAsIpc::new(&arrow_schema, &options);
370        let data_gen = IpcDataGenerator::default();
371        let schema = pair.0;
372        let write_options = pair.1;
373        let mut dictionary_tracker = DictionaryTracker::new(false);
374        let encoded_data = data_gen.schema_to_bytes_with_dictionary_tracker(schema, &mut dictionary_tracker, &write_options);
375        let mut schema_bytes = vec![];
376        arrow::ipc::writer::write_message(&mut schema_bytes, encoded_data, pair.1)
377            .map_err(|e| Status::internal(format!("Error encoding schema: {e}")))?;
378        Ok(schema_bytes)
379    }
380
381    async fn enqueue_job(
382        &self,
383        ctx: Arc<SessionContext>,
384        plan: &LogicalPlan,
385    ) -> Result<String, Status> {
386        let job_id = self.server.state.task_manager.generate_job_id();
387        let job_name = format!("Flight SQL job {job_id}");
388        self.server
389            .submit_job(&job_id, &job_name, ctx, plan)
390            .await
391            .map_err(|e| {
392                let msg = format!("Failed to send JobQueued event for {job_id}: {e:?}");
393                error!("{}", msg);
394                Status::internal(msg)
395            })?;
396        Ok(job_id)
397    }
398
399    fn create_resp(
400        schema_bytes: Vec<u8>,
401        fieps: Vec<FlightEndpoint>,
402        num_rows: i64,
403        num_bytes: i64,
404    ) -> Response<FlightInfo> {
405        let flight_desc = FlightDescriptor {
406            r#type: DescriptorType::Cmd.into(),
407            cmd: Vec::new().into(),
408            path: vec![],
409        };
410        let info = FlightInfo {
411            schema: schema_bytes.into(),
412            flight_descriptor: Some(flight_desc),
413            endpoint: fieps,
414            total_records: num_rows,
415            total_bytes: num_bytes,
416            ordered: false,
417            app_metadata: Bytes::new(),
418        };
419        Response::new(info)
420    }
421
422    async fn execute_plan(
423        &self,
424        ctx: Arc<SessionContext>,
425        plan: &LogicalPlan,
426    ) -> Result<Response<FlightInfo>, Status> {
427        let job_id = self.enqueue_job(ctx, plan).await?;
428
429        // poll for job completion
430        let mut num_rows = 0;
431        let mut num_bytes = 0;
432        let fieps = loop {
433            sleep(Duration::from_millis(100)).await;
434            let completed = if let Some(comp) = self.check_job(&job_id).await? {
435                comp
436            } else {
437                continue;
438            };
439            let fieps = self
440                .job_to_fetch_part(completed, &mut num_rows, &mut num_bytes)
441                .await?;
442            break fieps;
443        };
444
445        // Generate response
446        let schema_bytes = self.df_schema_to_arrow(plan.schema())?;
447        let resp = Self::create_resp(schema_bytes, fieps, num_rows, num_bytes);
448        Ok(resp)
449    }
450
451    async fn record_batch_to_resp(
452        rb: RecordBatch,
453    ) -> Result<
454        Response<Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>>>,
455        Status,
456    > {
457        type FlightResult = Result<FlightData, Status>;
458        let (tx, rx): (Sender<FlightResult>, Receiver<FlightResult>) = channel(2);
459        let schema = rb.schema();
460        let flights = batches_to_flight_data(&schema, vec![rb])
461            .map_err(|_| Status::internal("Error encoding batches".to_string()))?;
462        for flight in flights {
463            tx.send(Ok(flight))
464                .await
465                .map_err(|_| Status::internal("Error sending flight".to_string()))?;
466        }
467        let resp = Response::new(Box::pin(ReceiverStream::new(rx))
468            as Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>);
469        Ok(resp)
470    }
471
472    fn batch_to_schema_resp(
473        &self,
474        data: &RecordBatch,
475        name: &str,
476    ) -> Result<Response<FlightInfo>, Status> {
477        let num_bytes = data.get_array_memory_size() as i64;
478        let schema = data.schema();
479        let num_rows = data.num_rows() as i64;
480
481        let fieps = self.make_local_fieps(name)?;
482        let schema_bytes = self.schema_to_arrow(schema)?;
483        let resp = Self::create_resp(schema_bytes, fieps, num_rows, num_bytes);
484        Ok(resp)
485    }
486}
487
488#[tonic::async_trait]
489impl FlightSqlService for FlightSqlServiceImpl {
490    type FlightService = FlightSqlServiceImpl;
491
492    async fn do_handshake(
493        &self,
494        request: Request<Streaming<HandshakeRequest>>,
495    ) -> Result<
496        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
497        Status,
498    > {
499        debug!("do_handshake");
500        for md in request.metadata().iter() {
501            debug!("{:?}", md);
502        }
503
504        let basic = "Basic ";
505        let authorization = request
506            .metadata()
507            .get("authorization")
508            .ok_or_else(|| Status::invalid_argument("authorization field not present"))?
509            .to_str()
510            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
511        if !authorization.starts_with(basic) {
512            Err(Status::invalid_argument(format!(
513                "Auth type not implemented: {authorization}"
514            )))?;
515        }
516        let bytes = base64::engine::general_purpose::STANDARD
517            .decode(&authorization[basic.len()..])
518            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
519        let str = String::from_utf8(bytes)
520            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
521        let parts: Vec<_> = str.split(':').collect();
522        if parts.len() != 2 {
523            Err(Status::invalid_argument("Invalid authorization header"))?;
524        }
525        let user = parts[0];
526        let pass = parts[1];
527        if user != "admin" || pass != "password" {
528            Err(Status::unauthenticated("Invalid credentials!"))?
529        }
530
531        let token = self.create_ctx().await?;
532
533        let result = HandshakeResponse {
534            protocol_version: 0,
535            payload: token.as_bytes().to_vec().into(),
536        };
537        let result = Ok(result);
538        let output = futures::stream::iter(vec![result]);
539        let str = format!("Bearer {token}");
540        let mut resp: Response<Pin<Box<dyn Stream<Item = Result<_, _>> + Send>>> =
541            Response::new(Box::pin(output));
542        let md = MetadataValue::try_from(str)
543            .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
544        resp.metadata_mut().insert("authorization", md);
545        Ok(resp)
546    }
547
548    async fn do_get_fallback(
549        &self,
550        request: Request<Ticket>,
551        message: arrow_flight::sql::Any,
552    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
553        debug!("do_get_fallback type_url: {}", message.type_url);
554        let ctx = self.get_ctx(&request)?;
555        if !message.is::<protobuf::Action>() {
556            Err(Status::unimplemented(format!(
557                "do_get: The defined request is invalid: {}",
558                message.type_url
559            )))?
560        }
561
562        let action: protobuf::Action = message
563            .unpack()
564            .map_err(|e| Status::internal(format!("{e:?}")))?
565            .ok_or_else(|| Status::internal("Expected an Action but got None!"))?;
566        let fp = match &action.action_type {
567            Some(FetchPartition(fp)) => fp.clone(),
568            None => Err(Status::internal("Expected an ActionType but got None!"))?,
569        };
570
571        // Well-known job ID: respond with the data
572        match fp.job_id.as_str() {
573            "get_flight_info_table_types" => {
574                debug!("Responding with table types");
575                let rb = FlightSqlServiceImpl::table_types().map_err(|_| {
576                    Status::internal("Error getting table types".to_string())
577                })?;
578                let resp = Self::record_batch_to_resp(rb).await?;
579                return Ok(resp);
580            }
581            "get_flight_info_tables" => {
582                debug!("Responding with tables");
583                let rb = self
584                    .tables(ctx)
585                    .map_err(|_| Status::internal("Error getting tables".to_string()))?;
586                let resp = Self::record_batch_to_resp(rb).await?;
587                return Ok(resp);
588            }
589            _ => {}
590        }
591
592        // Proxy the flight
593        let addr = format!("http://{}:{}", fp.host, fp.port);
594        debug!("Scheduler proxying flight for to {}", addr);
595        let connection =
596            create_grpc_client_connection(addr.clone())
597                .await
598                .map_err(|e| {
599                    Status::internal(format!(
600                    "Error connecting to kapot scheduler or executor at {addr}: {e:?}"
601                ))
602                })?;
603        let mut flight_client = FlightServiceClient::new(connection);
604        let buf = action.encode_to_vec();
605        let request = Request::new(Ticket { ticket: buf.into() });
606
607        let stream = flight_client
608            .do_get(request)
609            .await
610            .map_err(|e| Status::internal(format!("{e:?}")))?
611            .into_inner();
612        Ok(Response::new(Box::pin(stream)))
613    }
614
615    /// Get a FlightDataStream containing the data related to the supported XDBC types.
616    async fn do_get_xdbc_type_info(
617        &self,
618        _query: CommandGetXdbcTypeInfo,
619        _request: Request<Ticket>,
620    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
621        debug!("do_get_xdbc_type_info");
622        Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
623    }
624
625    async fn get_flight_info_statement(
626        &self,
627        query: CommandStatementQuery,
628        request: Request<FlightDescriptor>,
629    ) -> Result<Response<FlightInfo>, Status> {
630        debug!("get_flight_info_statement query:\n{}", query.query);
631
632        let ctx = self.get_ctx(&request)?;
633        let plan = Self::prepare_statement(&query.query, &ctx).await?;
634        let resp = self.execute_plan(ctx, &plan).await?;
635
636        debug!("Returning flight info...");
637        Ok(resp)
638    }
639
640    async fn get_flight_info_prepared_statement(
641        &self,
642        handle: CommandPreparedStatementQuery,
643        request: Request<FlightDescriptor>,
644    ) -> Result<Response<FlightInfo>, Status> {
645        debug!("get_flight_info_prepared_statement");
646        let ctx = self.get_ctx(&request)?;
647        let handle = Uuid::from_slice(handle.prepared_statement_handle.as_ref())
648            .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
649        let plan = self.get_plan(&handle)?;
650        let resp = self.execute_plan(ctx, &plan).await?;
651
652        debug!("Responding to query {}...", handle);
653        Ok(resp)
654    }
655
656    async fn get_flight_info_catalogs(
657        &self,
658        _query: CommandGetCatalogs,
659        _request: Request<FlightDescriptor>,
660    ) -> Result<Response<FlightInfo>, Status> {
661        debug!("get_flight_info_catalogs");
662        Err(Status::unimplemented("Implement get_flight_info_catalogs"))
663    }
664    async fn get_flight_info_schemas(
665        &self,
666        _query: CommandGetDbSchemas,
667        _request: Request<FlightDescriptor>,
668    ) -> Result<Response<FlightInfo>, Status> {
669        debug!("get_flight_info_schemas");
670        Err(Status::unimplemented("Implement get_flight_info_schemas"))
671    }
672
673    async fn get_flight_info_tables(
674        &self,
675        _query: CommandGetTables,
676        request: Request<FlightDescriptor>,
677    ) -> Result<Response<FlightInfo>, Status> {
678        debug!("get_flight_info_tables");
679        let ctx = self.get_ctx(&request)?;
680        let data = self
681            .tables(ctx)
682            .map_err(|e| Status::internal(format!("Error getting tables: {e}")))?;
683        let resp = self.batch_to_schema_resp(&data, "get_flight_info_tables")?;
684        Ok(resp)
685    }
686
687    async fn get_flight_info_table_types(
688        &self,
689        _query: CommandGetTableTypes,
690        _request: Request<FlightDescriptor>,
691    ) -> Result<Response<FlightInfo>, Status> {
692        debug!("get_flight_info_table_types");
693        let data = FlightSqlServiceImpl::table_types()
694            .map_err(|e| Status::internal(format!("Error getting table types: {e}")))?;
695        let resp = self.batch_to_schema_resp(&data, "get_flight_info_table_types")?;
696        Ok(resp)
697    }
698
699    async fn get_flight_info_sql_info(
700        &self,
701        _query: CommandGetSqlInfo,
702        _request: Request<FlightDescriptor>,
703    ) -> Result<Response<FlightInfo>, Status> {
704        debug!("get_flight_info_sql_info");
705        // TODO: implement for FlightSQL JDBC to work
706        Err(Status::unimplemented("Implement CommandGetSqlInfo"))
707    }
708    async fn get_flight_info_primary_keys(
709        &self,
710        _query: CommandGetPrimaryKeys,
711        _request: Request<FlightDescriptor>,
712    ) -> Result<Response<FlightInfo>, Status> {
713        debug!("get_flight_info_primary_keys");
714        Err(Status::unimplemented(
715            "Implement get_flight_info_primary_keys",
716        ))
717    }
718    async fn get_flight_info_exported_keys(
719        &self,
720        _query: CommandGetExportedKeys,
721        _request: Request<FlightDescriptor>,
722    ) -> Result<Response<FlightInfo>, Status> {
723        debug!("get_flight_info_exported_keys");
724        Err(Status::unimplemented(
725            "Implement get_flight_info_exported_keys",
726        ))
727    }
728    async fn get_flight_info_imported_keys(
729        &self,
730        _query: CommandGetImportedKeys,
731        _request: Request<FlightDescriptor>,
732    ) -> Result<Response<FlightInfo>, Status> {
733        debug!("get_flight_info_imported_keys");
734        Err(Status::unimplemented(
735            "Implement get_flight_info_imported_keys",
736        ))
737    }
738    async fn get_flight_info_cross_reference(
739        &self,
740        _query: CommandGetCrossReference,
741        _request: Request<FlightDescriptor>,
742    ) -> Result<Response<FlightInfo>, Status> {
743        debug!("get_flight_info_cross_reference");
744        Err(Status::unimplemented(
745            "Implement get_flight_info_cross_reference",
746        ))
747    }
748
749    /// Get a FlightInfo to extract information about the supported XDBC types.
750    async fn get_flight_info_xdbc_type_info(
751        &self,
752        _query: CommandGetXdbcTypeInfo,
753        _request: Request<FlightDescriptor>,
754    ) -> Result<Response<FlightInfo>, Status> {
755        debug!("get_flight_info_xdbc_type_info");
756        Err(Status::unimplemented(
757            "Implement get_flight_info_xdbc_type_info",
758        ))
759    }
760
761    async fn do_get_statement(
762        &self,
763        _ticket: TicketStatementQuery,
764        _request: Request<Ticket>,
765    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
766        debug!("do_get_statement");
767        // let handle = Uuid::from_slice(&ticket.statement_handle)
768        //     .map_err(|e| Status::internal(format!("Error decoding ticket: {}", e)))?;
769        // let statements = self.statements.try_lock()
770        //     .map_err(|e| Status::internal(format!("Error decoding ticket: {}", e)))?;
771        // let plan = statements.get(&handle);
772        Err(Status::unimplemented("Implement do_get_statement"))
773    }
774
775    async fn do_get_prepared_statement(
776        &self,
777        _query: CommandPreparedStatementQuery,
778        _request: Request<Ticket>,
779    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
780        debug!("do_get_prepared_statement");
781        Err(Status::unimplemented("Implement do_get_prepared_statement"))
782    }
783    async fn do_get_catalogs(
784        &self,
785        _query: CommandGetCatalogs,
786        _request: Request<Ticket>,
787    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
788        debug!("do_get_catalogs");
789        Err(Status::unimplemented("Implement do_get_catalogs"))
790    }
791    async fn do_get_schemas(
792        &self,
793        _query: CommandGetDbSchemas,
794        _request: Request<Ticket>,
795    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
796        debug!("do_get_schemas");
797        Err(Status::unimplemented("Implement do_get_schemas"))
798    }
799    async fn do_get_tables(
800        &self,
801        _query: CommandGetTables,
802        _request: Request<Ticket>,
803    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
804        debug!("do_get_tables");
805        Err(Status::unimplemented("Implement do_get_tables"))
806    }
807    async fn do_get_table_types(
808        &self,
809        _query: CommandGetTableTypes,
810        _request: Request<Ticket>,
811    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
812        debug!("do_get_table_types");
813        Err(Status::unimplemented("Implement do_get_table_types"))
814    }
815    async fn do_get_sql_info(
816        &self,
817        _query: CommandGetSqlInfo,
818        _request: Request<Ticket>,
819    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
820        debug!("do_get_sql_info");
821        Err(Status::unimplemented("Implement do_get_sql_info"))
822    }
823    async fn do_get_primary_keys(
824        &self,
825        _query: CommandGetPrimaryKeys,
826        _request: Request<Ticket>,
827    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
828        debug!("do_get_primary_keys");
829        Err(Status::unimplemented("Implement do_get_primary_keys"))
830    }
831    async fn do_get_exported_keys(
832        &self,
833        _query: CommandGetExportedKeys,
834        _request: Request<Ticket>,
835    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
836        debug!("do_get_exported_keys");
837        Err(Status::unimplemented("Implement do_get_exported_keys"))
838    }
839    async fn do_get_imported_keys(
840        &self,
841        _query: CommandGetImportedKeys,
842        _request: Request<Ticket>,
843    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
844        debug!("do_get_imported_keys");
845        Err(Status::unimplemented("Implement do_get_imported_keys"))
846    }
847    async fn do_get_cross_reference(
848        &self,
849        _query: CommandGetCrossReference,
850        _request: Request<Ticket>,
851    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
852        debug!("do_get_cross_reference");
853        Err(Status::unimplemented("Implement do_get_cross_reference"))
854    }
855    // do_put
856    async fn do_put_statement_update(
857        &self,
858        _ticket: CommandStatementUpdate,
859        _request: Request<PeekableFlightDataStream>,
860    ) -> Result<i64, Status> {
861        debug!("do_put_statement_update");
862        Err(Status::unimplemented("Implement do_put_statement_update"))
863    }
864    async fn do_put_prepared_statement_query(
865        &self,
866        _query: CommandPreparedStatementQuery,
867        _request: Request<PeekableFlightDataStream>,
868    ) -> Result<DoPutPreparedStatementResult, Status> {
869        debug!("do_put_prepared_statement_query");
870        Err(Status::unimplemented(
871            "Implement do_put_prepared_statement_query",
872        ))
873    }
874    async fn do_put_prepared_statement_update(
875        &self,
876        handle: CommandPreparedStatementUpdate,
877        request: Request<PeekableFlightDataStream>,
878    ) -> Result<i64, Status> {
879        debug!("do_put_prepared_statement_update");
880        let ctx = self.get_ctx(&request)?;
881        let handle = Uuid::from_slice(handle.prepared_statement_handle.as_ref())
882            .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
883        let plan = self.get_plan(&handle)?;
884        let _ = self.execute_plan(ctx, &plan).await?;
885        debug!("Sending -1 rows affected");
886        Ok(-1)
887    }
888
889    async fn do_action_create_prepared_statement(
890        &self,
891        query: ActionCreatePreparedStatementRequest,
892        request: Request<Action>,
893    ) -> Result<ActionCreatePreparedStatementResult, Status> {
894        debug!("do_action_create_prepared_statement");
895        let ctx = self.get_ctx(&request)?;
896        let plan = Self::prepare_statement(&query.query, &ctx).await?;
897        let schema_bytes = self.df_schema_to_arrow(plan.schema())?;
898        let handle = self.cache_plan(plan)?;
899        debug!("Prepared statement {}:\n{}", handle, query.query);
900        let res = ActionCreatePreparedStatementResult {
901            prepared_statement_handle: handle.as_bytes().to_vec().into(),
902            dataset_schema: schema_bytes.into(),
903            parameter_schema: Vec::new().into(), // TODO: parameters
904        };
905        Ok(res)
906    }
907
908    async fn do_action_close_prepared_statement(
909        &self,
910        handle: ActionClosePreparedStatementRequest,
911        _request: Request<Action>,
912    ) -> Result<(), Status> {
913        debug!("do_action_close_prepared_statement");
914        let handle = Uuid::from_slice(handle.prepared_statement_handle.as_ref())
915            .inspect(|id| {
916                debug!("Closing {}", id);
917            })
918            .map_err(|e| Status::internal(format!("Failed to parse handle: {e:?}")))?;
919
920        self.remove_plan(handle)
921    }
922
923    /// Get a FlightInfo for executing a substrait plan.
924    async fn get_flight_info_substrait_plan(
925        &self,
926        _query: CommandStatementSubstraitPlan,
927        _request: Request<FlightDescriptor>,
928    ) -> Result<Response<FlightInfo>, Status> {
929        debug!("get_flight_info_substrait_plan");
930        Err(Status::unimplemented(
931            "Implement get_flight_info_substrait_plan",
932        ))
933    }
934
935    /// Execute a substrait plan
936    async fn do_put_substrait_plan(
937        &self,
938        _query: CommandStatementSubstraitPlan,
939        _request: Request<PeekableFlightDataStream>,
940    ) -> Result<i64, Status> {
941        debug!("do_put_substrait_plan");
942        Err(Status::unimplemented("Implement do_put_substrait_plan"))
943    }
944
945    /// Create a prepared substrait plan.
946    async fn do_action_create_prepared_substrait_plan(
947        &self,
948        _query: ActionCreatePreparedSubstraitPlanRequest,
949        _request: Request<Action>,
950    ) -> Result<ActionCreatePreparedStatementResult, Status> {
951        debug!("do_action_create_prepared_substrait_plan");
952        Err(Status::unimplemented(
953            "Implement do_action_create_prepared_substrait_plan",
954        ))
955    }
956
957    /// Begin a transaction
958    async fn do_action_begin_transaction(
959        &self,
960        _query: ActionBeginTransactionRequest,
961        _request: Request<Action>,
962    ) -> Result<ActionBeginTransactionResult, Status> {
963        debug!("do_action_begin_transaction");
964        Err(Status::unimplemented(
965            "Implement do_action_begin_transaction",
966        ))
967    }
968
969    /// End a transaction
970    async fn do_action_end_transaction(
971        &self,
972        _query: ActionEndTransactionRequest,
973        _request: Request<Action>,
974    ) -> Result<(), Status> {
975        debug!("do_action_end_transaction");
976        Err(Status::unimplemented("Implement do_action_end_transaction"))
977    }
978
979    /// Begin a savepoint
980    async fn do_action_begin_savepoint(
981        &self,
982        _query: ActionBeginSavepointRequest,
983        _request: Request<Action>,
984    ) -> Result<ActionBeginSavepointResult, Status> {
985        debug!("do_action_begin_savepoint");
986        Err(Status::unimplemented("Implement do_action_begin_savepoint"))
987    }
988
989    /// End a savepoint
990    async fn do_action_end_savepoint(
991        &self,
992        _query: ActionEndSavepointRequest,
993        _request: Request<Action>,
994    ) -> Result<(), Status> {
995        debug!("do_action_end_savepoint");
996        Err(Status::unimplemented("Implement do_action_end_savepoint"))
997    }
998
999    /// Cancel a query
1000    async fn do_action_cancel_query(
1001        &self,
1002        _query: ActionCancelQueryRequest,
1003        _request: Request<Action>,
1004    ) -> Result<ActionCancelQueryResult, Status> {
1005        debug!("do_action_cancel_query");
1006        Err(Status::unimplemented("Implement do_action_cancel_query"))
1007    }
1008
1009    /// Register a new SqlInfo result, making it available when calling GetSqlInfo.
1010    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
1011}