datafusion_table_providers/flight/
sql.rs1use std::collections::HashMap;
21
22use arrow_flight::error::Result;
23use arrow_flight::sql::client::FlightSqlServiceClient;
24use async_trait::async_trait;
25use tonic::transport::Channel;
26
27use crate::flight::{FlightDriver, FlightMetadata, FlightProperties};
28
/// Option key whose value is the SQL text sent to the Flight SQL service.
pub const QUERY: &str = "flight.sql.query";
/// Option key whose value is the username used for the handshake.
pub const USERNAME: &str = "flight.sql.username";
/// Option key whose value is the password used for the handshake.
pub const PASSWORD: &str = "flight.sql.password";
/// Prefix of option keys that carry gRPC headers; the remainder of the key
/// (after this prefix) is used as the header name.
pub const HEADER_PREFIX: &str = "flight.sql.header.";
33
34#[derive(Clone, Debug, Default)]
43pub struct FlightSqlDriver {
44 properties_template: FlightProperties,
45 persistent_headers: bool,
46}
47
48impl FlightSqlDriver {
49 pub fn new() -> Self {
50 Default::default()
51 }
52
53 pub fn with_properties_template(mut self, properties_template: FlightProperties) -> Self {
58 self.properties_template = properties_template;
59 self
60 }
61
62 pub fn with_persistent_headers(mut self, persistent_headers: bool) -> Self {
64 self.persistent_headers = persistent_headers;
65 self
66 }
67}
68
69#[async_trait]
70impl FlightDriver for FlightSqlDriver {
71 async fn metadata(
72 &self,
73 channel: Channel,
74 options: &HashMap<String, String>,
75 ) -> Result<FlightMetadata> {
76 let mut client = FlightSqlServiceClient::new(channel);
77 let mut handshake_headers = self.properties_template.grpc_headers.clone();
78 let headers_overlay = options.iter().filter_map(|(key, value)| {
79 key.strip_prefix(HEADER_PREFIX)
80 .map(|header_name| (header_name.to_owned(), value.to_owned()))
81 });
82 handshake_headers.extend(headers_overlay);
83 for (name, value) in &handshake_headers {
84 client.set_header(name, value)
85 }
86 if let Some(username) = options.get(USERNAME) {
87 let default_password = "".to_string();
88 let password = options.get(PASSWORD).unwrap_or(&default_password);
89 client.handshake(username, password).await.ok();
90 }
91 let info = client.execute(options[QUERY].clone(), None).await?;
92 let mut partition_headers = if self.persistent_headers {
93 handshake_headers
94 } else {
95 HashMap::default()
96 };
97 if let Some(token) = client.token() {
98 partition_headers.insert("authorization".into(), format!("Bearer {token}"));
99 }
100 let props = self
101 .properties_template
102 .clone()
103 .with_grpc_headers(partition_headers);
104 FlightMetadata::try_new(info, props)
105 }
106}