use std::any::Any;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::{DataFusionError, Result as DataFusionResult, exec_err};
use datafusion::datasource::TableType;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::logical_expr::TableProviderFilterPushDown;
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties as _, PlanProperties,
};
use datafusion::prelude::Expr;
use datafusion::{catalog::TableProvider, datasource::MemTable};
use futures::{Stream, StreamExt as _};
use re_mutex::Mutex;
use re_viewer_context::AsyncRuntimeHandle;
/// Lifecycle of one cache generation shared between the provider, the
/// background fill task, and reader streams.
#[derive(Debug, Clone)]
pub enum CacheState {
    /// No scan has happened yet; the fill task has not been spawned.
    NotStarted,
    /// The background fill task is appending batches as they arrive.
    Streaming,
    /// All input batches were received; future scans are served from this
    /// [`MemTable`] snapshot.
    Complete(Arc<MemTable>),
    /// The fill failed; the shared error is surfaced to every reader.
    Failed(Arc<DataFusionError>),
}
/// Mutable state of one cache generation; always accessed behind a [`Mutex`],
/// shared by the provider, the fill task, and in-flight reader streams.
#[derive(Debug)]
struct StreamingCacheInner {
    /// Schema of the wrapped input table; fixed for the generation's lifetime.
    schema: SchemaRef,
    /// Batches received so far, in arrival order; readers index into this.
    cached_batches: Vec<RecordBatch>,
    /// Where the fill currently stands (see [`CacheState`]).
    state: CacheState,
    /// Wakers of reader streams waiting for more batches or a state change.
    wakers: Vec<Waker>,
}
impl StreamingCacheInner {
    /// Creates an empty cache for `schema`, starting in the `NotStarted` state.
    fn new(schema: SchemaRef) -> Self {
        Self {
            state: CacheState::NotStarted,
            cached_batches: Vec::new(),
            wakers: Vec::new(),
            schema,
        }
    }

    /// Wakes every registered waker and clears the registration list.
    fn wake_all(&mut self) {
        self.wakers.drain(..).for_each(Waker::wake);
    }

    /// Registers `waker` to be notified on the next wake-up, skipping it if an
    /// equivalent waker is already registered.
    fn register_waker(&mut self, waker: &Waker) {
        let already_registered = self
            .wakers
            .iter()
            .any(|existing| existing.will_wake(waker));
        if !already_registered {
            self.wakers.push(waker.clone());
        }
    }
}
/// A [`TableProvider`] that streams `input_table` into an in-memory cache on
/// first scan, then serves subsequent scans from that cache.
pub struct StreamingCacheTableProvider {
    /// The wrapped table; scanned once per cache generation.
    input_table: Arc<dyn TableProvider>,
    /// Schema captured from `input_table` at construction time.
    schema: SchemaRef,
    /// Outer mutex guards *which* cache generation is current (`refresh` swaps
    /// it); the inner `Arc<Mutex<…>>` is the generation itself, shared with
    /// the fill task and any in-flight reader streams.
    cache: Mutex<Arc<Mutex<StreamingCacheInner>>>,
    /// Runtime used to spawn the background fill task.
    runtime: AsyncRuntimeHandle,
}
impl fmt::Debug for StreamingCacheTableProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Outer lock selects the current cache generation, inner lock reads it.
        let current = self.cache.lock();
        let guard = current.lock();
        let mut dbg = f.debug_struct("StreamingCacheTableProvider");
        dbg.field("schema", &self.schema);
        dbg.field("state", &guard.state);
        dbg.field("cached_batches", &guard.cached_batches.len());
        dbg.finish_non_exhaustive()
    }
}
impl StreamingCacheTableProvider {
    /// Wraps `input_table`; nothing is read from it until the first `scan`.
    pub fn new(input_table: Arc<dyn TableProvider>, runtime: AsyncRuntimeHandle) -> Self {
        let schema = input_table.schema();
        Self {
            input_table,
            schema: Arc::clone(&schema),
            cache: Mutex::new(Arc::new(Mutex::new(StreamingCacheInner::new(schema)))),
            runtime,
        }
    }

    /// Discards the current cache generation; the next `scan` starts a fresh
    /// fill. An in-flight fill task notices it holds the last `Arc` to the old
    /// generation and stops early (see `stream_to_cache`).
    pub fn refresh(&self) {
        let mut cache = self.cache.lock();
        *cache = Arc::new(Mutex::new(StreamingCacheInner::new(Arc::clone(
            &self.schema,
        ))));
    }

    /// True once the current generation has received every batch and built its
    /// `MemTable` snapshot.
    pub fn is_complete(&self) -> bool {
        matches!(self.cache.lock().lock().state, CacheState::Complete(_))
    }

    /// Number of batches cached so far in the current generation.
    pub fn cached_batch_count(&self) -> usize {
        self.cache.lock().lock().cached_batches.len()
    }

    /// Snapshot of the current generation's state.
    pub fn state(&self) -> CacheState {
        self.cache.lock().lock().state.clone()
    }

    /// Background fill: drains `input_exec`'s single partition into `cache`,
    /// waking readers after every batch, then snapshots all batches into a
    /// `MemTable` and flips the state to `Complete`.
    ///
    /// Errors (unexpected partition count, execute/stream failures, `MemTable`
    /// construction) are returned to the caller, which records them in the
    /// cache as `CacheState::Failed` and wakes readers.
    async fn stream_to_cache(
        input_exec: Arc<dyn ExecutionPlan>,
        task_ctx: Arc<TaskContext>,
        cache: &Arc<Mutex<StreamingCacheInner>>,
    ) -> DataFusionResult<()> {
        if input_exec.output_partitioning().partition_count() != 1 {
            return exec_err!(
                "Expected exactly one partition stream for input to StreamingCacheTableProvider"
            );
        }
        let mut stream = input_exec.execute(0, task_ctx)?;
        while let Some(result) = stream.next().await {
            let batch = result?;
            // The guard is scoped to this iteration so it is dropped before
            // the next `stream.next().await` — the lock is never held across
            // an await point.
            let mut guard = cache.lock();
            guard.cached_batches.push(batch);
            guard.wake_all();
            // If this task holds the only reference, the provider has swapped
            // this generation out (`refresh`) and every reader stream is gone;
            // nothing we produce can ever be observed, so stop filling.
            if Arc::strong_count(cache) == 1 {
                return Ok(());
            }
        }
        let mut guard = cache.lock();
        // Keep `cached_batches` intact for readers still mid-stream; the
        // `MemTable` snapshot serves all future scans.
        let batches = guard.cached_batches.clone();
        let mem_table = MemTable::try_new(Arc::clone(&guard.schema), vec![batches])?;
        guard.state = CacheState::Complete(Arc::new(mem_table));
        guard.wake_all();
        Ok(())
    }
}
#[async_trait]
impl TableProvider for StreamingCacheTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Serves the scan from the `MemTable` once the cache is complete;
    /// otherwise returns a plan that streams batches as the single background
    /// fill task produces them (spawning that task on the first scan).
    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        // Decide under the locks, act outside them: the `scan` calls below are
        // async and must not run while either mutex is held.
        enum Action {
            UseMemTable(Arc<MemTable>),
            ReturnError(Arc<DataFusionError>),
            StartStreaming(Arc<Mutex<StreamingCacheInner>>),
            ReturnStreamingPlan(Arc<Mutex<StreamingCacheInner>>),
        }
        let action = {
            let inner_cache = Arc::clone(&self.cache.lock());
            let mut inner = inner_cache.lock();
            match &inner.state {
                CacheState::Complete(mem_table) => Action::UseMemTable(Arc::clone(mem_table)),
                CacheState::Failed(err) => Action::ReturnError(Arc::clone(err)),
                CacheState::NotStarted => {
                    // Claim the fill before releasing the lock so concurrent
                    // scans take the `Streaming` branch instead of racing to
                    // spawn a second fill task.
                    inner.state = CacheState::Streaming;
                    drop(inner);
                    Action::StartStreaming(inner_cache)
                }
                CacheState::Streaming => {
                    drop(inner);
                    Action::ReturnStreamingPlan(inner_cache)
                }
            }
        };
        match action {
            Action::UseMemTable(mem_table) => {
                mem_table.scan(state, projection, filters, limit).await
            }
            Action::ReturnError(error) => Err(DataFusionError::Shared(error)),
            Action::StartStreaming(inner_cache) => {
                // BUGFIX: if the input scan fails here, the state was already
                // flipped to `Streaming` but no fill task will ever run, so
                // concurrently-created streams would pend forever. Record the
                // failure and wake readers before propagating the error.
                let input_exec = match self
                    .input_table
                    .scan(state, projection, filters, limit)
                    .await
                {
                    Ok(exec) => exec,
                    Err(err) => {
                        let err = Arc::new(err);
                        let mut guard = inner_cache.lock();
                        guard.state = CacheState::Failed(Arc::clone(&err));
                        guard.wake_all();
                        return Err(DataFusionError::Shared(err));
                    }
                };
                let cache_ref = Arc::clone(&inner_cache);
                let task_ctx = state.task_ctx();
                self.runtime.spawn_future(async move {
                    // Any fill failure is recorded so every reader observes it.
                    if let Err(err) = Self::stream_to_cache(input_exec, task_ctx, &cache_ref).await
                    {
                        let mut guard = cache_ref.lock();
                        guard.state = CacheState::Failed(Arc::new(err));
                        guard.wake_all();
                    }
                });
                Ok(Arc::new(CachedStreamingExec::new(inner_cache)))
            }
            Action::ReturnStreamingPlan(inner_cache) => {
                Ok(Arc::new(CachedStreamingExec::new(inner_cache)))
            }
        }
    }

    /// Report every filter as unsupported: the cached batches cannot re-apply
    /// arbitrary predicates, so DataFusion must filter above this provider.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        Ok(vec![
            TableProviderFilterPushDown::Unsupported;
            filters.len()
        ])
    }
}
/// Leaf [`ExecutionPlan`] that reads from a (possibly still filling) cache
/// generation as a single partition.
struct CachedStreamingExec {
    /// Shared cache generation this plan reads from.
    cache: Arc<Mutex<StreamingCacheInner>>,
    /// Plan properties: one unknown partition, incremental emission, bounded.
    properties: PlanProperties,
}
impl CachedStreamingExec {
    /// Builds an exec node over `cache`: a single unknown partition emitting
    /// incrementally, with bounded output.
    fn new(cache: Arc<Mutex<StreamingCacheInner>>) -> Self {
        let schema = Arc::clone(&cache.lock().schema);
        let eq_properties = EquivalenceProperties::new(schema);
        Self {
            properties: PlanProperties::new(
                eq_properties,
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            ),
            cache,
        }
    }
}
impl fmt::Debug for CachedStreamingExec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CachedStreamingExec")
.field("schema", &self.cache.lock().schema)
.finish_non_exhaustive()
}
}
impl DisplayAs for CachedStreamingExec {
    /// Every display format renders the same bare node name.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CachedStreamingExec")
    }
}
impl ExecutionPlan for CachedStreamingExec {
    fn name(&self) -> &'static str {
        "CachedStreamingExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        let guard = self.cache.lock();
        Arc::clone(&guard.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// Leaf node: batches come from the shared cache, not from child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if children.is_empty() {
            Ok(self)
        } else {
            Err(DataFusionError::Internal(
                "CachedStreamingExec expects no children".to_owned(),
            ))
        }
    }

    /// Only partition 0 exists; it replays the cache from the beginning.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        match partition {
            0 => {
                let stream = CachedRecordBatchStream::new(Arc::clone(&self.cache));
                Ok(Box::pin(stream))
            }
            partition => Err(DataFusionError::Internal(format!(
                "CachedStreamingExec only supports partition 0, got {partition}"
            ))),
        }
    }
}
/// Record-batch stream that replays `cached_batches` in order and waits (via
/// registered wakers) while the fill task is still producing.
pub struct CachedRecordBatchStream {
    /// Shared cache generation being read.
    cache: Arc<Mutex<StreamingCacheInner>>,
    /// Index of the next batch in `cached_batches` to yield.
    read_pos: usize,
    /// Set after a `Failed` error has been yielded once; the stream then ends.
    error_yielded: bool,
}
impl CachedRecordBatchStream {
    /// Starts a reader positioned at the beginning of the shared batch cache.
    fn new(cache: Arc<Mutex<StreamingCacheInner>>) -> Self {
        Self {
            read_pos: 0,
            error_yielded: false,
            cache,
        }
    }
}
impl Stream for CachedRecordBatchStream {
    type Item = DataFusionResult<RecordBatch>;

    /// Yields cached batches in order; pends (with a registered waker) while
    /// the fill task is still producing; ends after `Complete` or one error.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // An error is terminal: after yielding it once, behave as exhausted.
        if self.error_yielded {
            return Poll::Ready(None);
        }
        let mut cache = self.cache.lock();
        // Fast path: a batch we have not yielded yet is already cached.
        if self.read_pos < cache.cached_batches.len() {
            let batch = cache.cached_batches[self.read_pos].clone();
            drop(cache);
            self.read_pos += 1;
            return Poll::Ready(Some(Ok(batch)));
        }
        // Caught up with the writer; what happens next depends on the state.
        match &cache.state {
            CacheState::Complete(_) => Poll::Ready(None),
            CacheState::Failed(err) => {
                let err = Arc::clone(err);
                drop(cache);
                self.error_yielded = true;
                Poll::Ready(Some(Err(DataFusionError::Shared(err))))
            }
            CacheState::NotStarted | CacheState::Streaming => {
                // Register under the same lock as the length check above so a
                // wake-up between check and registration cannot be lost.
                cache.register_waker(cx.waker());
                Poll::Pending
            }
        }
    }
}
impl RecordBatchStream for CachedRecordBatchStream {
    /// Schema of every batch this stream yields, taken from the shared cache.
    fn schema(&self) -> SchemaRef {
        let guard = self.cache.lock();
        Arc::clone(&guard.schema)
    }
}