datafusion_table_providers/
flight.rs1use std::any::Any;
22use std::collections::HashMap;
23use std::error::Error;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27use crate::flight::exec::FlightExec;
28use arrow_flight::error::FlightError;
29use arrow_flight::FlightInfo;
30use async_trait::async_trait;
31use datafusion::arrow::datatypes::SchemaRef;
32use datafusion::catalog::{Session, TableProviderFactory};
33use datafusion::common::stats::Precision;
34use datafusion::common::{DataFusionError, Statistics};
35use datafusion::datasource::TableProvider;
36use datafusion::logical_expr::{CreateExternalTable, Expr, TableType};
37use datafusion::physical_plan::ExecutionPlan;
38use serde::{Deserialize, Serialize};
39use tonic::transport::{Channel, ClientTlsConfig};
40
41pub mod codec;
42mod exec;
43pub mod sql;
44
45pub use exec::enforce_schema;
46
/// [`TableProviderFactory`] that creates [`FlightTable`]s, delegating the retrieval
/// of Flight metadata to a pluggable [`FlightDriver`].
#[derive(Clone, Debug)]
pub struct FlightTableFactory {
    // Held in an `Arc` because tables opened in refresh mode keep their own
    // handle to the driver (see `MetadataSupplier::Refresh`).
    driver: Arc<dyn FlightDriver>,
}
99
100impl FlightTableFactory {
101 pub fn new(driver: Arc<dyn FlightDriver>) -> Self {
103 Self { driver }
104 }
105
106 pub async fn open_table(
108 &self,
109 entry_point: impl Into<String>,
110 options: HashMap<String, String>,
111 ) -> datafusion::common::Result<FlightTable> {
112 let origin = entry_point.into();
113 let channel = flight_channel(&origin).await?;
114 let metadata = self
115 .driver
116 .metadata(channel.clone(), &options)
117 .await
118 .map_err(to_df_err)?;
119 let num_rows = precision(metadata.info.total_records);
120 let total_byte_size = precision(metadata.info.total_bytes);
121 let logical_schema = metadata.schema.clone();
122 let stats = Statistics {
123 num_rows,
124 total_byte_size,
125 column_statistics: vec![],
126 };
127 let metadata_supplier = if metadata.props.reusable_flight_info {
128 MetadataSupplier::Reusable(Arc::new(metadata))
129 } else {
130 MetadataSupplier::Refresh {
131 driver: self.driver.clone(),
132 channel,
133 options,
134 }
135 };
136 Ok(FlightTable {
137 metadata_supplier,
138 origin,
139 logical_schema,
140 stats,
141 })
142 }
143}
144
145#[async_trait]
146impl TableProviderFactory for FlightTableFactory {
147 async fn create(
148 &self,
149 _state: &dyn Session,
150 cmd: &CreateExternalTable,
151 ) -> datafusion::common::Result<Arc<dyn TableProvider>> {
152 let table = self.open_table(&cmd.location, cmd.options.clone()).await?;
153 Ok(Arc::new(table))
154 }
155}
156
/// Strategy for obtaining [`FlightMetadata`] from a Flight service.
///
/// [`FlightTableFactory`] calls [`FlightDriver::metadata`] once when a table is
/// opened. Unless the returned properties mark the `FlightInfo` as reusable, each
/// scan of the table calls it again.
#[async_trait]
pub trait FlightDriver: Sync + Send + Debug {
    /// Produces the [`FlightMetadata`] for a table.
    ///
    /// `channel` is an already-connected tonic channel to the table's endpoint.
    /// `options` are the user-supplied table options (for SQL-created tables, the
    /// options of the `CreateExternalTable` command).
    async fn metadata(
        &self,
        channel: Channel,
        options: &HashMap<String, String>,
    ) -> arrow_flight::error::Result<FlightMetadata>;
}
171
/// Metadata handed to the execution plan for a Flight scan: the service's
/// `FlightInfo`, driver-chosen [`FlightProperties`], and the Arrow schema of the data.
#[derive(Clone, Debug)]
pub struct FlightMetadata {
    /// `FlightInfo` from the service. Its `total_records` / `total_bytes` seed
    /// the table statistics.
    info: FlightInfo,
    /// Driver-selected properties (headers, size limits, reuse policy, ...).
    props: FlightProperties,
    /// Arrow schema of the data. [`FlightMetadata::try_new`] decodes it from `info`.
    schema: SchemaRef,
}
183
184impl FlightMetadata {
185 pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
187 Self {
188 info,
189 props,
190 schema,
191 }
192 }
193
194 pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
196 let schema = Arc::new(info.clone().try_decode_schema()?);
197 Ok(Self::new(info, props, schema))
198 }
199}
200
201impl TryFrom<FlightInfo> for FlightMetadata {
202 type Error = FlightError;
203
204 fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
205 Self::try_new(info, FlightProperties::default())
206 }
207}
208
/// Driver-selected properties that accompany a `FlightInfo`.
///
/// Built with [`FlightProperties::new`] plus the `with_*` methods. Field order
/// determines the serde serialization order.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct FlightProperties {
    // NOTE(review): presumably tells execution to treat result streams as
    // unbounded; consumed outside this module. Confirm there.
    unbounded_streams: bool,
    // gRPC headers supplied by the driver. Presumably attached to Flight
    // requests at execution time; confirm in the exec module.
    grpc_headers: HashMap<String, String>,
    // Encoding/decoding message size limits (default: `usize::MAX` for both).
    size_limits: SizeLimits,
    // When true, `FlightTableFactory::open_table` caches the initial metadata
    // for all scans. When false, every scan re-queries the driver.
    reusable_flight_info: bool,
}
218
219impl FlightProperties {
220 pub fn new() -> Self {
221 Default::default()
222 }
223
224 pub fn with_unbounded_streams(mut self, unbounded_streams: bool) -> Self {
226 self.unbounded_streams = unbounded_streams;
227 self
228 }
229
230 pub fn with_grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
232 self.grpc_headers = grpc_headers;
233 self
234 }
235
236 pub fn with_size_limits(mut self, size_limits: SizeLimits) -> Self {
238 self.size_limits = size_limits;
239 self
240 }
241
242 pub fn with_reusable_flight_info(mut self, reusable_flight_info: bool) -> Self {
245 self.reusable_flight_info = reusable_flight_info;
246 self
247 }
248}
249
/// Message size limits for Flight encoding and decoding.
///
/// Both default to `usize::MAX`. Field order determines the serde serialization order.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct SizeLimits {
    // Limit applied when encoding messages (units not stated here; presumably bytes).
    encoding: usize,
    // Limit applied when decoding messages (units not stated here; presumably bytes).
    decoding: usize,
}
256
257impl SizeLimits {
258 pub fn new(encoding: usize, decoding: usize) -> Self {
259 Self { encoding, decoding }
260 }
261}
262
263impl Default for SizeLimits {
264 fn default() -> Self {
265 Self {
266 encoding: usize::MAX,
268 decoding: usize::MAX,
269 }
270 }
271}
272
/// Source of the [`FlightMetadata`] used when planning a scan.
#[derive(Clone, Debug)]
enum MetadataSupplier {
    /// Metadata fetched once at table-open time and shared by every scan.
    Reusable(Arc<FlightMetadata>),
    /// Everything needed to ask the driver for fresh metadata on each scan.
    Refresh {
        driver: Arc<dyn FlightDriver>,
        channel: Channel,
        options: HashMap<String, String>,
    },
}
282
283impl MetadataSupplier {
284 async fn flight_metadata(&self) -> datafusion::common::Result<Arc<FlightMetadata>> {
285 match self {
286 Self::Reusable(metadata) => Ok(metadata.clone()),
287 Self::Refresh {
288 driver,
289 channel,
290 options,
291 } => Ok(Arc::new(
292 driver
293 .metadata(channel.clone(), options)
294 .await
295 .map_err(to_df_err)?,
296 )),
297 }
298 }
299}
300
/// DataFusion [`TableProvider`] backed by an Arrow Flight endpoint.
pub struct FlightTable {
    // Supplies the metadata for each scan (cached or re-fetched).
    metadata_supplier: MetadataSupplier,
    // Endpoint URI the table was opened with; passed to `FlightExec`.
    origin: String,
    // Schema reported to DataFusion, captured at open time.
    logical_schema: SchemaRef,
    // Table-level statistics derived from the initial `FlightInfo`.
    stats: Statistics,
}
308
309impl std::fmt::Debug for FlightTable {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 f.debug_struct("FlightTable")
312 .field("origin", &self.origin)
313 .field("logical_schema", &self.logical_schema)
314 .field("stats", &self.stats)
315 .finish()
316 }
317}
318
319#[async_trait]
320impl TableProvider for FlightTable {
321 fn as_any(&self) -> &dyn Any {
322 self
323 }
324
325 fn schema(&self) -> SchemaRef {
326 self.logical_schema.clone()
327 }
328
329 fn table_type(&self) -> TableType {
330 TableType::View
331 }
332
333 async fn scan(
334 &self,
335 _state: &dyn Session,
336 projection: Option<&Vec<usize>>,
337 _filters: &[Expr],
338 _limit: Option<usize>,
339 ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
340 let metadata = self.metadata_supplier.flight_metadata().await?;
341 Ok(Arc::new(FlightExec::try_new(
342 metadata.as_ref(),
343 projection,
344 &self.origin,
345 )?))
346 }
347
348 fn statistics(&self) -> Option<Statistics> {
349 Some(self.stats.clone())
350 }
351}
352
353fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
354 DataFusionError::External(Box::new(err))
355}
356
357async fn flight_channel(source: impl Into<String>) -> datafusion::common::Result<Channel> {
358 let tls_config = ClientTlsConfig::new().with_enabled_roots();
359 Channel::from_shared(source.into())
360 .map_err(to_df_err)?
361 .tls_config(tls_config)
362 .map_err(to_df_err)?
363 .connect()
364 .await
365 .map_err(to_df_err)
366}
367
368fn precision(total: i64) -> Precision<usize> {
369 if total < 0 {
370 Precision::Absent
371 } else {
372 Precision::Exact(total as usize)
373 }
374}