use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};
use super::{ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use futures::Stream;
/// Execution plan that serves record batches already held in memory.
#[derive(Debug)]
pub struct MemoryExec {
/// Batches grouped by partition; the outer index is the partition number
/// passed to `execute`.
partitions: Vec<Vec<RecordBatch>>,
/// Schema reported by `schema()` and handed to every stream created by
/// `execute`.
schema: SchemaRef,
/// Optional column indices handed to every stream created by `execute`;
/// with `None` the stored batches are emitted unchanged.
projection: Option<Vec<usize>>,
}
#[async_trait]
impl ExecutionPlan for MemoryExec {
    /// Returns this plan as [`Any`] so callers can downcast to `MemoryExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns a clone of the schema handle this plan was created with.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// This plan reads from memory and has no child plans.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![]
    }

    /// Reports unknown partitioning with one output partition per stored
    /// batch list.
    fn output_partitioning(&self) -> Partitioning {
        Partitioning::UnknownPartitioning(self.partitions.len())
    }

    /// Always fails: the plan has no children, so there is nothing to replace.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Internal`] unconditionally.
    fn with_new_children(
        &self,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Err(DataFusionError::Internal(format!(
            "Children cannot be replaced in {:?}",
            self
        )))
    }

    /// Creates a stream over the batches stored for `partition`.
    ///
    /// The stored batches, schema and projection are cloned into the new
    /// stream, so the plan can be executed repeatedly.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Internal`] when `partition` is not a valid
    /// index into the stored partitions, and forwards any error from
    /// `MemoryStream::try_new`.
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
        // Checked lookup: an out-of-range partition is reported as an error
        // instead of panicking on a slice index.
        let batches = self.partitions.get(partition).ok_or_else(|| {
            DataFusionError::Internal(format!(
                "MemoryExec invalid partition {} (expected less than {})",
                partition,
                self.partitions.len()
            ))
        })?;
        Ok(Box::pin(MemoryStream::try_new(
            batches.clone(),
            self.schema.clone(),
            self.projection.clone(),
        )?))
    }
}
impl MemoryExec {
pub fn try_new(
partitions: &[Vec<RecordBatch>],
schema: SchemaRef,
projection: Option<Vec<usize>>,
) -> Result<Self> {
Ok(Self {
partitions: partitions.to_vec(),
schema,
projection,
})
}
}
/// Stream that yields a fixed list of record batches one at a time.
pub(crate) struct MemoryStream {
/// Batches to emit, in order.
data: Vec<RecordBatch>,
/// Schema reported by the stream and used when building projected batches.
schema: SchemaRef,
/// Optional column indices to select from each batch before emitting it.
projection: Option<Vec<usize>>,
/// Position in `data` of the next batch to emit.
index: usize,
}
impl MemoryStream {
pub fn try_new(
data: Vec<RecordBatch>,
schema: SchemaRef,
projection: Option<Vec<usize>>,
) -> Result<Self> {
Ok(Self {
data,
schema,
projection,
index: 0,
})
}
}
impl Stream for MemoryStream {
    type Item = ArrowResult<RecordBatch>;

    /// Emits the next stored batch, or `None` once all batches were emitted.
    ///
    /// The stream never returns `Poll::Pending`: every batch is already in
    /// memory. With a projection, the batch is rebuilt from the selected
    /// columns using the stream's schema, and any error from
    /// `RecordBatch::try_new` is yielded as the item.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        _: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let position = self.index;
        if position >= self.data.len() {
            return Poll::Ready(None);
        }
        // Advance before building the item so a failed projection still
        // consumes its batch.
        self.index = position + 1;

        let batch = &self.data[position];
        let item = match &self.projection {
            Some(columns) => RecordBatch::try_new(
                self.schema.clone(),
                columns.iter().map(|i| batch.column(*i).clone()).collect(),
            ),
            None => Ok(batch.clone()),
        };
        Poll::Ready(Some(item))
    }

    /// Returns the exact number of batches that have not been emitted yet.
    ///
    /// The hint covers the *remaining* items, so it shrinks as the stream is
    /// polled rather than always reporting the total batch count.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `index` never exceeds `data.len()`; saturating_sub just keeps this
        // panic-free.
        let remaining = self.data.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}
impl RecordBatchStream for MemoryStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}