datafusion_table_providers/flight/
sql.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Default [`FlightDriver`] for Flight SQL.

use std::collections::HashMap;

use arrow_flight::error::{FlightError, Result};
use arrow_flight::sql::client::FlightSqlServiceClient;
use async_trait::async_trait;
use tonic::transport::Channel;

use crate::flight::{FlightDriver, FlightMetadata, FlightProperties};
28
/// Table option carrying the SQL query to execute; it is required by [`FlightSqlDriver`].
pub const QUERY: &str = "flight.sql.query";
/// Table option carrying the username for the basic-authentication `Handshake` (optional).
pub const USERNAME: &str = "flight.sql.username";
/// Table option carrying the password paired with [`USERNAME`] (optional).
pub const PASSWORD: &str = "flight.sql.password";
/// Prefix marking table options that are forwarded as gRPC headers.
/// The part of the option key after this prefix becomes the header name.
pub const HEADER_PREFIX: &str = "flight.sql.header.";
33
34/// Default Flight SQL driver. Requires a [QUERY] to be passed as a table option.
35/// If [USERNAME] (and optionally [PASSWORD]) are passed,
36/// will perform the `Handshake` using basic authentication.
37/// Any additional headers for the `GetFlightInfo` call can be passed as table options
38/// using the [HEADER_PREFIX] prefix.
39/// If a token is returned by the server with the handshake response, it will be
40/// stored as a gRPC authorization header within the returned [FlightMetadata],
41/// to be sent with the subsequent `DoGet` requests.
42#[derive(Clone, Debug, Default)]
43pub struct FlightSqlDriver {
44    properties_template: FlightProperties,
45    persistent_headers: bool,
46}
47
48impl FlightSqlDriver {
49    pub fn new() -> Self {
50        Default::default()
51    }
52
53    /// Custom flight properties to be returned from the metadata call instead of the default ones.
54    /// The headers (if any) will only be used for the Handshake/GetFlightInfo calls by default.
55    /// This behaviour can be changed by calling [Self::with_persistent_headers] below.
56    /// Headers provided as options for the metadata call will overwrite the template ones.
57    pub fn with_properties_template(mut self, properties_template: FlightProperties) -> Self {
58        self.properties_template = properties_template;
59        self
60    }
61
62    /// Propagate the static headers configured for Handshake/GetFlightInfo to the subsequent DoGet calls.
63    pub fn with_persistent_headers(mut self, persistent_headers: bool) -> Self {
64        self.persistent_headers = persistent_headers;
65        self
66    }
67}
68
69#[async_trait]
70impl FlightDriver for FlightSqlDriver {
71    async fn metadata(
72        &self,
73        channel: Channel,
74        options: &HashMap<String, String>,
75    ) -> Result<FlightMetadata> {
76        let mut client = FlightSqlServiceClient::new(channel);
77        let mut handshake_headers = self.properties_template.grpc_headers.clone();
78        let headers_overlay = options.iter().filter_map(|(key, value)| {
79            key.strip_prefix(HEADER_PREFIX)
80                .map(|header_name| (header_name.to_owned(), value.to_owned()))
81        });
82        handshake_headers.extend(headers_overlay);
83        for (name, value) in &handshake_headers {
84            client.set_header(name, value)
85        }
86        if let Some(username) = options.get(USERNAME) {
87            let default_password = "".to_string();
88            let password = options.get(PASSWORD).unwrap_or(&default_password);
89            client.handshake(username, password).await.ok();
90        }
91        let info = client.execute(options[QUERY].clone(), None).await?;
92        let mut partition_headers = if self.persistent_headers {
93            handshake_headers
94        } else {
95            HashMap::default()
96        };
97        if let Some(token) = client.token() {
98            partition_headers.insert("authorization".into(), format!("Bearer {token}"));
99        }
100        let props = self
101            .properties_template
102            .clone()
103            .with_grpc_headers(partition_headers);
104        FlightMetadata::try_new(info, props)
105    }
106}