datafusion_table_providers/
flight.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Generic [FlightTableFactory] that can connect to Arrow Flight services,
19//! with a [sql::FlightSqlDriver] provided out-of-the-box.
20
21use std::any::Any;
22use std::collections::HashMap;
23use std::error::Error;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27use crate::flight::exec::FlightExec;
28use arrow_flight::error::FlightError;
29use arrow_flight::FlightInfo;
30use async_trait::async_trait;
31use datafusion::arrow::datatypes::SchemaRef;
32use datafusion::catalog::{Session, TableProviderFactory};
33use datafusion::common::stats::Precision;
34use datafusion::common::{DataFusionError, Statistics};
35use datafusion::datasource::TableProvider;
36use datafusion::logical_expr::{CreateExternalTable, Expr, TableType};
37use datafusion::physical_plan::ExecutionPlan;
38use serde::{Deserialize, Serialize};
39use tonic::transport::{Channel, ClientTlsConfig};
40
41pub mod codec;
42mod exec;
43pub mod sql;
44
45pub use exec::enforce_schema;
46
47/// Generic Arrow Flight data source. Requires a [FlightDriver] that allows implementors
48/// to integrate any custom Flight RPC service by producing a [FlightMetadata] for some DDL.
49///
50/// # Sample usage:
51/// ```
52/// # use arrow_flight::{FlightClient, FlightDescriptor};
53/// # use datafusion::prelude::SessionContext;
54/// # use datafusion_table_providers::flight::{FlightDriver, FlightMetadata, FlightTableFactory};
55/// # use std::collections::HashMap;
56/// # use std::sync::Arc;
57/// # use tonic::transport::Channel;
58///
59/// #[derive(Debug, Clone, Default)]
60/// struct CustomFlightDriver {}
61///
62/// #[async_trait::async_trait]
63/// impl FlightDriver for CustomFlightDriver {
64///     async fn metadata(
65///         &self,
66///         channel: Channel,
67///         opts: &HashMap<String, String>,
68///     ) -> arrow_flight::error::Result<FlightMetadata> {
69///         let mut client = FlightClient::new(channel);
70///         // for simplicity, we'll just assume the server expects a string command and no handshake
71///         let descriptor = FlightDescriptor::new_cmd(opts["flight.command"].clone());
72///         let flight_info = client.get_flight_info(descriptor).await?;
73///         FlightMetadata::try_from(flight_info)
74///     }
75/// }
76///
77/// #[tokio::main]
78/// async fn main() {
79///     let ctx = SessionContext::new();
80///     ctx.state_ref().write().table_factories_mut().insert(
81///         "CUSTOM_FLIGHT".into(),
82///         Arc::new(FlightTableFactory::new(Arc::new(
83///             CustomFlightDriver::default(),
84///         ))),
85///     );
86///     let _ = ctx.sql(
87///         r#"
88///         CREATE EXTERNAL TABLE custom_flight_table STORED AS CUSTOM_FLIGHT
89///         LOCATION 'https://custom.flight.rpc'
90///         OPTIONS ('flight.command' 'AI, show me the data!')
91///     "#,
92///     ); // no .await here, so we don't actually try to connect to the bogus URL
93/// }
94/// ```
95#[derive(Clone, Debug)]
96pub struct FlightTableFactory {
97    driver: Arc<dyn FlightDriver>,
98}
99
100impl FlightTableFactory {
101    /// Create a data source using the provided driver
102    pub fn new(driver: Arc<dyn FlightDriver>) -> Self {
103        Self { driver }
104    }
105
106    /// Convenient way to create a [FlightTable] programatically, as an alternative to DDL.
107    pub async fn open_table(
108        &self,
109        entry_point: impl Into<String>,
110        options: HashMap<String, String>,
111    ) -> datafusion::common::Result<FlightTable> {
112        let origin = entry_point.into();
113        let channel = flight_channel(&origin).await?;
114        let metadata = self
115            .driver
116            .metadata(channel.clone(), &options)
117            .await
118            .map_err(to_df_err)?;
119        let num_rows = precision(metadata.info.total_records);
120        let total_byte_size = precision(metadata.info.total_bytes);
121        let logical_schema = metadata.schema.clone();
122        let stats = Statistics {
123            num_rows,
124            total_byte_size,
125            column_statistics: vec![],
126        };
127        let metadata_supplier = if metadata.props.reusable_flight_info {
128            MetadataSupplier::Reusable(Arc::new(metadata))
129        } else {
130            MetadataSupplier::Refresh {
131                driver: self.driver.clone(),
132                channel,
133                options,
134            }
135        };
136        Ok(FlightTable {
137            metadata_supplier,
138            origin,
139            logical_schema,
140            stats,
141        })
142    }
143}
144
145#[async_trait]
146impl TableProviderFactory for FlightTableFactory {
147    async fn create(
148        &self,
149        _state: &dyn Session,
150        cmd: &CreateExternalTable,
151    ) -> datafusion::common::Result<Arc<dyn TableProvider>> {
152        let table = self.open_table(&cmd.location, cmd.options.clone()).await?;
153        Ok(Arc::new(table))
154    }
155}
156
157/// Extension point for integrating any Flight RPC service as a [FlightTableFactory].
158/// Handles the initial `GetFlightInfo` call and all its prerequisites (such as `Handshake`),
159/// to produce a [FlightMetadata].
160#[async_trait]
161pub trait FlightDriver: Sync + Send + Debug {
162    /// Returns a [FlightMetadata] from the specified channel,
163    /// according to the provided table options.
164    /// The driver must provide at least a [FlightInfo] in order to construct a flight metadata.
165    async fn metadata(
166        &self,
167        channel: Channel,
168        options: &HashMap<String, String>,
169    ) -> arrow_flight::error::Result<FlightMetadata>;
170}
171
172/// The information that a [FlightDriver] must produce
173/// in order to register flights as DataFusion tables.
174#[derive(Clone, Debug)]
175pub struct FlightMetadata {
176    /// FlightInfo object produced by the driver
177    info: FlightInfo,
178    /// Various knobs that control execution
179    props: FlightProperties,
180    /// Arrow schema. Can be enforced by the driver or inferred from the FlightInfo
181    schema: SchemaRef,
182}
183
184impl FlightMetadata {
185    /// Customize everything that is in the driver's control
186    pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
187        Self {
188            info,
189            props,
190            schema,
191        }
192    }
193
194    /// Customize flight properties and try to use the FlightInfo schema
195    pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
196        let schema = Arc::new(info.clone().try_decode_schema()?);
197        Ok(Self::new(info, props, schema))
198    }
199}
200
201impl TryFrom<FlightInfo> for FlightMetadata {
202    type Error = FlightError;
203
204    fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
205        Self::try_new(info, FlightProperties::default())
206    }
207}
208
209/// Meant to gradually encapsulate all sorts of knobs required
210/// for controlling the protocol and query execution details.
211#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
212pub struct FlightProperties {
213    unbounded_streams: bool,
214    grpc_headers: HashMap<String, String>,
215    size_limits: SizeLimits,
216    reusable_flight_info: bool,
217}
218
219impl FlightProperties {
220    pub fn new() -> Self {
221        Default::default()
222    }
223
224    /// Whether the service will produce infinite streams
225    pub fn with_unbounded_streams(mut self, unbounded_streams: bool) -> Self {
226        self.unbounded_streams = unbounded_streams;
227        self
228    }
229
230    /// gRPC headers to use on subsequent calls.
231    pub fn with_grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
232        self.grpc_headers = grpc_headers;
233        self
234    }
235
236    /// Max sizes in bytes for encoded/decoded gRPC messages.
237    pub fn with_size_limits(mut self, size_limits: SizeLimits) -> Self {
238        self.size_limits = size_limits;
239        self
240    }
241
242    /// Whether the FlightInfo objects produced by the service can be used multiple times
243    /// or need to be refreshed before every table scan.
244    pub fn with_reusable_flight_info(mut self, reusable_flight_info: bool) -> Self {
245        self.reusable_flight_info = reusable_flight_info;
246        self
247    }
248}
249
250/// Message size limits to be passed to the underlying gRPC library.
251#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
252pub struct SizeLimits {
253    encoding: usize,
254    decoding: usize,
255}
256
257impl SizeLimits {
258    pub fn new(encoding: usize, decoding: usize) -> Self {
259        Self { encoding, decoding }
260    }
261}
262
263impl Default for SizeLimits {
264    fn default() -> Self {
265        Self {
266            // no limits
267            encoding: usize::MAX,
268            decoding: usize::MAX,
269        }
270    }
271}
272
273#[derive(Clone, Debug)]
274enum MetadataSupplier {
275    Reusable(Arc<FlightMetadata>),
276    Refresh {
277        driver: Arc<dyn FlightDriver>,
278        channel: Channel,
279        options: HashMap<String, String>,
280    },
281}
282
283impl MetadataSupplier {
284    async fn flight_metadata(&self) -> datafusion::common::Result<Arc<FlightMetadata>> {
285        match self {
286            Self::Reusable(metadata) => Ok(metadata.clone()),
287            Self::Refresh {
288                driver,
289                channel,
290                options,
291            } => Ok(Arc::new(
292                driver
293                    .metadata(channel.clone(), options)
294                    .await
295                    .map_err(to_df_err)?,
296            )),
297        }
298    }
299}
300
301/// Table provider that wraps a specific flight from an Arrow Flight service
302pub struct FlightTable {
303    metadata_supplier: MetadataSupplier,
304    origin: String,
305    logical_schema: SchemaRef,
306    stats: Statistics,
307}
308
309impl std::fmt::Debug for FlightTable {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        f.debug_struct("FlightTable")
312            .field("origin", &self.origin)
313            .field("logical_schema", &self.logical_schema)
314            .field("stats", &self.stats)
315            .finish()
316    }
317}
318
319#[async_trait]
320impl TableProvider for FlightTable {
321    fn as_any(&self) -> &dyn Any {
322        self
323    }
324
325    fn schema(&self) -> SchemaRef {
326        self.logical_schema.clone()
327    }
328
329    fn table_type(&self) -> TableType {
330        TableType::View
331    }
332
333    async fn scan(
334        &self,
335        _state: &dyn Session,
336        projection: Option<&Vec<usize>>,
337        _filters: &[Expr],
338        _limit: Option<usize>,
339    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
340        let metadata = self.metadata_supplier.flight_metadata().await?;
341        Ok(Arc::new(FlightExec::try_new(
342            metadata.as_ref(),
343            projection,
344            &self.origin,
345        )?))
346    }
347
348    fn statistics(&self) -> Option<Statistics> {
349        Some(self.stats.clone())
350    }
351}
352
353fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
354    DataFusionError::External(Box::new(err))
355}
356
357async fn flight_channel(source: impl Into<String>) -> datafusion::common::Result<Channel> {
358    let tls_config = ClientTlsConfig::new().with_enabled_roots();
359    Channel::from_shared(source.into())
360        .map_err(to_df_err)?
361        .tls_config(tls_config)
362        .map_err(to_df_err)?
363        .connect()
364        .await
365        .map_err(to_df_err)
366}
367
368fn precision(total: i64) -> Precision<usize> {
369    if total < 0 {
370        Precision::Absent
371    } else {
372        Precision::Exact(total as usize)
373    }
374}