use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
use crate::flight::exec::FlightExec;
use arrow_flight::error::FlightError;
use arrow_flight::FlightInfo;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::catalog::{Session, TableProviderFactory};
use datafusion::common::stats::Precision;
use datafusion::common::{DataFusionError, Statistics};
use datafusion::datasource::TableProvider;
use datafusion::logical_expr::{CreateExternalTable, Expr, TableType};
use datafusion::physical_plan::ExecutionPlan;
use serde::{Deserialize, Serialize};
use tonic::transport::{Channel, ClientTlsConfig};
pub mod codec;
mod exec;
pub mod sql;
pub use exec::enforce_schema;
/// Creates [`FlightTable`] providers, delegating Flight metadata lookups to a
/// pluggable [`FlightDriver`].
///
/// Implements DataFusion's [`TableProviderFactory`] so it can serve
/// `CREATE EXTERNAL TABLE` statements, and can also be used directly via `open_table`.
#[derive(Clone, Debug)]
pub struct FlightTableFactory {
    /// Driver asked for the `FlightMetadata` of every table this factory opens.
    driver: Arc<dyn FlightDriver>,
}
impl FlightTableFactory {
pub fn new(driver: Arc<dyn FlightDriver>) -> Self {
Self { driver }
}
pub async fn open_table(
&self,
entry_point: impl Into<String>,
options: HashMap<String, String>,
) -> datafusion::common::Result<FlightTable> {
let origin = entry_point.into();
let channel = flight_channel(&origin).await?;
let metadata = self
.driver
.metadata(channel.clone(), &options)
.await
.map_err(to_df_err)?;
let num_rows = precision(metadata.info.total_records);
let total_byte_size = precision(metadata.info.total_bytes);
let logical_schema = metadata.schema.clone();
let stats = Statistics {
num_rows,
total_byte_size,
column_statistics: vec![],
};
let metadata_supplier = if metadata.props.reusable_flight_info {
MetadataSupplier::Reusable(Arc::new(metadata))
} else {
MetadataSupplier::Refresh {
driver: self.driver.clone(),
channel,
options,
}
};
Ok(FlightTable {
metadata_supplier,
origin,
logical_schema,
stats,
})
}
}
#[async_trait]
impl TableProviderFactory for FlightTableFactory {
    /// Opens a Flight table for a `CREATE EXTERNAL TABLE` statement, using its
    /// `location` as the endpoint and handing its `options` to the driver.
    async fn create(
        &self,
        _state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> datafusion::common::Result<Arc<dyn TableProvider>> {
        let options = cmd.options.clone();
        let provider: Arc<dyn TableProvider> =
            Arc::new(self.open_table(cmd.location.as_str(), options).await?);
        Ok(provider)
    }
}
/// Pluggable source of [`FlightMetadata`] for a Flight table.
///
/// The driver is called once when a table is opened. When the returned
/// properties do not mark the flight info as reusable, it is called again for
/// every scan of the table.
#[async_trait]
pub trait FlightDriver: Sync + Send + Debug {
    /// Fetches the flight info, properties and schema for a table.
    ///
    /// `channel` is a clone of the connection opened for the table's endpoint.
    /// `options` are the key/value options the table was opened with; for
    /// `CREATE EXTERNAL TABLE` they are the statement's options.
    async fn metadata(
        &self,
        channel: Channel,
        options: &HashMap<String, String>,
    ) -> arrow_flight::error::Result<FlightMetadata>;
}
/// Metadata used to plan scans of a Flight table.
///
/// Bundles the `FlightInfo`, the client-side [`FlightProperties`], and the table's
/// Arrow schema.
#[derive(Clone, Debug)]
pub struct FlightMetadata {
    /// Flight info produced by the driver; its `total_records` / `total_bytes`
    /// seed the table statistics when the table is opened.
    info: FlightInfo,
    /// Options that control how the table is accessed.
    props: FlightProperties,
    /// Arrow schema; the one returned at open time becomes the table's logical schema.
    schema: SchemaRef,
}
impl FlightMetadata {
pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
Self {
info,
props,
schema,
}
}
pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
let schema = Arc::new(info.clone().try_decode_schema()?);
Ok(Self::new(info, props, schema))
}
}
impl TryFrom<FlightInfo> for FlightMetadata {
type Error = FlightError;
fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
Self::try_new(info, FlightProperties::default())
}
}
/// Options describing how a Flight table should be accessed.
///
/// Construct with [`FlightProperties::new`] (all defaults) and the `with_*` setters.
/// Derives `Serialize`/`Deserialize`; NOTE(review): presumably so the properties
/// can travel with serialized plans (see the `codec` module) — confirm.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct FlightProperties {
    /// NOTE(review): presumably marks the table's result streams as unbounded
    /// for the physical plan — confirm in `exec`. Defaults to `false`.
    unbounded_streams: bool,
    /// Extra gRPC headers; presumably attached to requests made during execution
    /// — confirm in `exec`. Defaults to empty.
    grpc_headers: HashMap<String, String>,
    /// Encoding/decoding message size limits; both default to `usize::MAX`.
    size_limits: SizeLimits,
    /// When `true`, the metadata fetched while opening the table is cached and
    /// reused by every scan. When `false` (the default), each scan asks the
    /// driver for fresh metadata.
    reusable_flight_info: bool,
}
impl FlightProperties {
    /// Returns properties with every option at its default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with the `unbounded_streams` flag replaced.
    pub fn with_unbounded_streams(self, unbounded_streams: bool) -> Self {
        Self {
            unbounded_streams,
            ..self
        }
    }

    /// Returns `self` with the gRPC header map replaced (not merged).
    pub fn with_grpc_headers(self, grpc_headers: HashMap<String, String>) -> Self {
        Self {
            grpc_headers,
            ..self
        }
    }

    /// Returns `self` with the encoding/decoding size limits replaced.
    pub fn with_size_limits(self, size_limits: SizeLimits) -> Self {
        Self {
            size_limits,
            ..self
        }
    }

    /// Returns `self` with the `reusable_flight_info` flag replaced.
    pub fn with_reusable_flight_info(self, reusable_flight_info: bool) -> Self {
        Self {
            reusable_flight_info,
            ..self
        }
    }
}
/// Upper bounds for message encoding and decoding.
///
/// Both default to `usize::MAX`, which is effectively unlimited.
/// NOTE(review): presumably byte sizes of gRPC messages — confirm where they are applied.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct SizeLimits {
    /// Limit applied when encoding.
    encoding: usize,
    /// Limit applied when decoding.
    decoding: usize,
}
impl SizeLimits {
pub fn new(encoding: usize, decoding: usize) -> Self {
Self { encoding, decoding }
}
}
impl Default for SizeLimits {
fn default() -> Self {
Self {
encoding: usize::MAX,
decoding: usize::MAX,
}
}
}
/// Where a [`FlightTable`] gets the [`FlightMetadata`] for each scan.
#[derive(Clone, Debug)]
enum MetadataSupplier {
    /// Metadata fetched once when the table was opened, shared by every scan.
    Reusable(Arc<FlightMetadata>),
    /// Everything needed to ask the driver for fresh metadata on each scan.
    Refresh {
        /// Driver to re-query.
        driver: Arc<dyn FlightDriver>,
        /// Connection opened when the table was created; cloned per request.
        channel: Channel,
        /// Options the table was opened with, forwarded to the driver.
        options: HashMap<String, String>,
    },
}
impl MetadataSupplier {
    /// Returns the metadata a scan should plan against.
    ///
    /// `Reusable` hands back a clone of the cached `Arc`. `Refresh` calls
    /// `FlightDriver::metadata` again with a clone of the stored channel.
    ///
    /// # Errors
    /// Driver failures are wrapped as `DataFusionError::External`.
    async fn flight_metadata(&self) -> datafusion::common::Result<Arc<FlightMetadata>> {
        let (driver, channel, options) = match self {
            Self::Reusable(cached) => return Ok(Arc::clone(cached)),
            Self::Refresh {
                driver,
                channel,
                options,
            } => (driver, channel, options),
        };
        let fresh = driver
            .metadata(channel.clone(), options)
            .await
            .map_err(to_df_err)?;
        Ok(Arc::new(fresh))
    }
}
/// A DataFusion table provider backed by an Arrow Flight endpoint.
///
/// Schema and statistics are captured when the table is opened. The Flight
/// metadata used for each scan comes from `metadata_supplier`.
pub struct FlightTable {
    /// Supplies scan metadata, either cached or re-fetched per scan.
    metadata_supplier: MetadataSupplier,
    /// Endpoint URI the table was opened against; also handed to `FlightExec`.
    origin: String,
    /// Schema reported by `TableProvider::schema`.
    logical_schema: SchemaRef,
    /// Statistics reported by `TableProvider::statistics`, taken from the
    /// `FlightInfo` obtained when the table was opened.
    stats: Statistics,
}
impl std::fmt::Debug for FlightTable {
    /// Formats the table's origin, schema and statistics; the metadata supplier is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FlightTable");
        out.field("origin", &self.origin);
        out.field("logical_schema", &self.logical_schema);
        out.field("stats", &self.stats);
        out.finish()
    }
}
#[async_trait]
impl TableProvider for FlightTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The schema captured when the table was opened.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.logical_schema)
    }

    fn table_type(&self) -> TableType {
        TableType::View
    }

    /// Plans a `FlightExec` over the current metadata.
    ///
    /// Only `projection` is forwarded; `filters` and `limit` are ignored here.
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
        let metadata = self.metadata_supplier.flight_metadata().await?;
        let exec = FlightExec::try_new(&*metadata, projection, &self.origin)?;
        Ok(Arc::new(exec))
    }

    /// Statistics snapshotted from the `FlightInfo` seen when the table was opened.
    fn statistics(&self) -> Option<Statistics> {
        let snapshot = self.stats.clone();
        Some(snapshot)
    }
}
/// Wraps any thread-safe error as `DataFusionError::External`.
fn to_df_err<E>(err: E) -> DataFusionError
where
    E: Error + Send + Sync + 'static,
{
    let boxed: Box<dyn Error + Send + Sync> = Box::new(err);
    DataFusionError::External(boxed)
}
/// Opens a tonic channel to `source`, configured with
/// `ClientTlsConfig::new().with_enabled_roots()`.
///
/// # Errors
/// Invalid URIs, TLS configuration failures and connection failures are all
/// wrapped as `DataFusionError::External`.
async fn flight_channel(source: impl Into<String>) -> datafusion::common::Result<Channel> {
    let tls = ClientTlsConfig::new().with_enabled_roots();
    let endpoint = Channel::from_shared(source.into()).map_err(to_df_err)?;
    let endpoint = endpoint.tls_config(tls).map_err(to_df_err)?;
    endpoint.connect().await.map_err(to_df_err)
}
fn precision(total: i64) -> Precision<usize> {
if total < 0 {
Precision::Absent
} else {
Precision::Exact(total as usize)
}
}