re_datafusion 0.31.3

High-level query APIs
Documentation
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;

use crate::batch_coalescer::coalesce_exec::SizedCoalesceBatchesExec;
use crate::batch_coalescer::coalescer::CoalescerOptions;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::{Session, TableProvider};
use datafusion::common::not_impl_err;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::logical_expr::TableProviderFilterPushDown;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::streaming::{PartitionStream, StreamingTableExec};
use datafusion::prelude::Expr;
use futures_util::StreamExt as _;
use tokio_stream::Stream;

/// Contract for exposing a gRPC server-streaming call as a DataFusion table.
///
/// An implementor knows how to obtain the table schema, how to open the
/// streaming request, and how to turn each streamed message into an Arrow
/// [`RecordBatch`]. Filter pushdown and inserts are optional hooks with
/// conservative defaults.
#[async_trait]
pub trait GrpcStreamToTable:
    Clone + std::fmt::Debug + Send + Sync + Unpin + 'static
{
    /// Message type yielded by the underlying `tonic` stream.
    type GrpcStreamData;

    /// Fetches the Arrow schema of the table backed by this client.
    ///
    /// # Errors
    /// Returns a [`DataFusionError`] if the schema cannot be obtained.
    async fn fetch_schema(&mut self) -> DataFusionResult<SchemaRef>;

    /// Converts a single streamed message into a [`RecordBatch`].
    ///
    /// # Errors
    /// Returns a [`DataFusionError`] if the message cannot be converted.
    fn process_response(
        &mut self,
        response: Self::GrpcStreamData,
    ) -> DataFusionResult<RecordBatch>;

    /// Opens the streaming request and returns the `tonic` response that
    /// wraps the message stream.
    ///
    /// # Errors
    /// Returns a [`DataFusionError`] if the request cannot be started.
    async fn send_streaming_request(
        &mut self,
    ) -> DataFusionResult<tonic::Response<tonic::Streaming<Self::GrpcStreamData>>>;

    /// Reports, per filter, whether it can be pushed down to this source.
    ///
    /// The default answers `Unsupported` for every filter, producing exactly
    /// one entry per element of `filters`.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        let verdicts = filters
            .iter()
            .map(|_| TableProviderFilterPushDown::Unsupported)
            .collect();
        Ok(verdicts)
    }

    /// Builds an execution plan that inserts `_input` into this table.
    ///
    /// The default implementation always fails with a "not implemented"
    /// error; implementors that support writes override it.
    async fn insert_into(
        &self,
        _state: &dyn Session,
        _input: Arc<dyn ExecutionPlan>,
        _insert_op: InsertOp,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        not_impl_err!("Insert into not implemented for this table")
    }
}

/// DataFusion table provider backed by a [`GrpcStreamToTable`] client.
///
/// Instances are created through `prepare`, which fetches the schema from the
/// client before the provider is handed out.
#[derive(Debug)]
pub struct GrpcStreamProvider<T: GrpcStreamToTable> {
    /// Schema returned by the client's `fetch_schema` when the provider was prepared.
    schema: SchemaRef,
    /// Client that is cloned for each scan and consulted for pushdown/insert support.
    client: T,
}

impl<T: GrpcStreamToTable> GrpcStreamProvider<T> {
    /// Creates a shareable provider by first asking `client` for its schema.
    ///
    /// The client is consumed and stored alongside the fetched schema.
    ///
    /// # Errors
    /// Propagates whatever error `client.fetch_schema()` returns.
    pub async fn prepare(mut client: T) -> Result<Arc<Self>, DataFusionError> {
        let fetched_schema = client.fetch_schema().await?;
        let provider = Self {
            schema: fetched_schema,
            client,
        };
        Ok(Arc::new(provider))
    }
}

#[async_trait]
impl<T> TableProvider for GrpcStreamProvider<T>
where
    T: GrpcStreamToTable + Send + 'static,
    T::GrpcStreamData: Send + 'static,
{
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> datafusion::datasource::TableType {
        datafusion::datasource::TableType::Base
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        StreamingTableExec::try_new(
            self.schema.clone(),
            vec![Arc::new(GrpcStreamPartitionStream::new(
                &self.schema,
                self.client.clone(),
            ))],
            projection,
            Vec::default(),
            false,
            None,
        )
        .map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)
        .map(|exec| {
            Arc::new(SizedCoalesceBatchesExec::new(
                exec,
                CoalescerOptions {
                    target_batch_rows: crate::dataframe_query_common::DEFAULT_BATCH_ROWS,
                    target_batch_bytes: crate::dataframe_query_common::DEFAULT_BATCH_BYTES,
                    max_rows: limit,
                },
            )) as Arc<dyn ExecutionPlan>
        })
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        self.client.supports_filters_pushdown(filters)
    }

    async fn insert_into(
        &self,
        state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        self.client.insert_into(state, input, insert_op).await
    }
}

/// A partition stream for a gRPC-backed table: pairs the table schema with
/// the client that is cloned to open the stream each time the partition is executed.
#[derive(Debug)]
pub struct GrpcStreamPartitionStream<T: GrpcStreamToTable> {
    /// Schema reported for this partition.
    schema: SchemaRef,
    /// Client cloned on every `execute` call to drive a fresh stream.
    client: T,
}

impl<T: GrpcStreamToTable> GrpcStreamPartitionStream<T> {
    /// Creates a partition stream that shares `schema` (the `Arc` is cloned,
    /// not the schema itself) and takes ownership of `client`.
    fn new(schema: &SchemaRef, client: T) -> Self {
        let schema = Arc::clone(schema);
        Self { schema, client }
    }
}

/// Lets DataFusion's streaming table machinery run a gRPC-backed partition.
impl<T> PartitionStream for GrpcStreamPartitionStream<T>
where
    T: GrpcStreamToTable + Send + 'static,
    T::GrpcStreamData: Send + 'static,
{
    /// Borrows the schema this partition reports.
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Starts a new record batch stream for this partition.
    ///
    /// Each call works on a fresh clone of the stored client; the task
    /// context is not used.
    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        let client = self.client.clone();
        let stream = GrpcStream::execute(&self.schema, client);
        Box::pin(stream)
    }
}

/// Record batch stream that pairs a schema with a boxed stream of
/// `Result<RecordBatch>` items.
pub struct GrpcStream {
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Type-erased, pinned stream that is polled to produce each item.
    adapted_stream: Pin<Box<dyn Stream<Item = datafusion::common::Result<RecordBatch>> + Send>>,
}

impl GrpcStream {
    fn execute<T>(schema: &SchemaRef, mut client: T) -> Self
    where
        T::GrpcStreamData: Send + 'static,
        T: GrpcStreamToTable + Send + 'static,
    {
        let adapted_stream = Box::pin(async_stream::try_stream! {
            let mut stream = client.send_streaming_request().await.map_err(|err| DataFusionError::External(Box::new(
                tonic::Status::internal(err.to_string())
            )))?.into_inner();

            while let Some(msg) = stream.next().await {
                let msg = msg.map_err(|err| DataFusionError::External(Box::new(err)))?;
                let processed = client
                    .process_response(msg)
                    .map_err(|err| DataFusionError::External(Box::new(
                        tonic::Status::internal(err.to_string())
                    )))?;
                yield processed;
            }
        });

        Self {
            schema: Arc::clone(schema),
            adapted_stream,
        }
    }
}

/// Advertises the schema of the batches this stream produces.
impl RecordBatchStream for GrpcStream {
    /// Returns a new handle to the stored schema (only the `Arc` is cloned).
    fn schema(&self) -> SchemaRef {
        let schema = &self.schema;
        Arc::clone(schema)
    }
}

/// Forwards polling to the boxed inner stream.
impl Stream for GrpcStream {
    type Item = DataFusionResult<RecordBatch>;

    /// Polls the wrapped stream for its next item and returns the result as-is.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // The inner stream is already pinned on the heap, so re-borrowing the
        // pin is enough to poll it.
        let inner = self.adapted_stream.as_mut();
        inner.poll_next(cx)
    }
}