//! A [`TableProvider`] for stream sources, such as FIFO files, that are read
//! from and appended to in place.

use std::any::Any;
use std::fmt::Formatter;
use std::fs::{File, OpenOptions};
use std::io::BufReader;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use crate::datasource::provider::TableProviderFactory;
use crate::datasource::{create_ordering, TableProvider};
use crate::execution::context::SessionState;

use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_common::{plan_err, Constraints, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{CreateExternalTable, Expr, TableType};
use datafusion_physical_plan::insert::{DataSink, DataSinkExec};
use datafusion_physical_plan::metrics::MetricsSet;
use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use futures::StreamExt;
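/// A [`TableProviderFactory`] that creates [`StreamTable`] instances from
/// [`CreateExternalTable`] commands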
#[derive(Debug, Default)]
pub struct StreamTableFactory {}
#[async_trait]
impl TableProviderFactory for StreamTableFactory {
async fn create(
&self,
state: &SessionState,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>> {
let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into());
let location = cmd.location.clone();
let encoding = cmd.file_type.parse()?;
let config = StreamConfig::new_file(schema, location.into())
.with_encoding(encoding)
.with_order(cmd.order_exprs.clone())
.with_header(cmd.has_header)
.with_batch_size(state.config().batch_size())
.with_constraints(cmd.constraints.clone());
Ok(Arc::new(StreamTable(Arc::new(config))))
}
}
/// The data encoding of the file backing a [`StreamTable`]
#[derive(Debug, Clone)]
pub enum StreamEncoding {
    /// CSV records
    Csv,
    /// Newline-delimited JSON records
    Json,
}
impl FromStr for StreamEncoding {
type Err = DataFusionError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"csv" => Ok(Self::Csv),
"json" => Ok(Self::Json),
_ => plan_err!("Unrecognised StreamEncoding {}", s),
}
}
}
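/// The configuration for a [`StreamTable`]: the schema, file location,
/// encoding, and read/write options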
#[derive(Debug)]
pub struct StreamConfig {
schema: SchemaRef,
location: PathBuf,
batch_size: usize,
encoding: StreamEncoding,
header: bool,
order: Vec<Vec<Expr>>,
constraints: Constraints,
}
impl StreamConfig {
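    /// Create a new [`StreamConfig`] for the file at `location`, defaulting to
    /// CSV encoding, a batch size of 1024, no header, no ordering, and no
    /// constraints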
pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self {
Self {
schema,
location,
batch_size: 1024,
encoding: StreamEncoding::Csv,
order: vec![],
header: false,
constraints: Constraints::empty(),
}
}
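    /// Specify a sort order for the data in the stream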
pub fn with_order(mut self, order: Vec<Vec<Expr>>) -> Self {
self.order = order;
self
}
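    /// Specify the batch size (maximum number of rows per [`RecordBatch`])
    /// used when reading the stream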
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
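    /// Specify whether the file has a header row (only used by
    /// [`StreamEncoding::Csv`])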
pub fn with_header(mut self, header: bool) -> Self {
self.header = header;
self
}
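    /// Specify the [`StreamEncoding`] of the file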
pub fn with_encoding(mut self, encoding: StreamEncoding) -> Self {
self.encoding = encoding;
self
}
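    /// Assign the [`Constraints`] reported by the resulting [`StreamTable`]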
pub fn with_constraints(mut self, constraints: Constraints) -> Self {
self.constraints = constraints;
self
}
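    /// Create a [`RecordBatchReader`] that reads from `location` using the
    /// configured encoding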
fn reader(&self) -> Result<Box<dyn RecordBatchReader>> {
let file = File::open(&self.location)?;
let schema = self.schema.clone();
match &self.encoding {
StreamEncoding::Csv => {
let reader = arrow::csv::ReaderBuilder::new(schema)
.with_header(self.header)
.with_batch_size(self.batch_size)
.build(file)?;
Ok(Box::new(reader))
}
StreamEncoding::Json => {
let reader = arrow::json::ReaderBuilder::new(schema)
.with_batch_size(self.batch_size)
.build(BufReader::new(file))?;
Ok(Box::new(reader))
}
}
}
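    /// Create a [`RecordBatchWriter`] that appends to `location` using the
    /// configured encoding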
fn writer(&self) -> Result<Box<dyn RecordBatchWriter>> {
match &self.encoding {
StreamEncoding::Csv => {
                // Only write a header if the file does not already exist,
                // i.e. we are not appending to previously written data
                let header = self.header && !self.location.exists();
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.location)?;
let writer = arrow::csv::WriterBuilder::new()
.with_header(header)
.build(file);
Ok(Box::new(writer))
}
StreamEncoding::Json => {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.location)?;
Ok(Box::new(arrow::json::LineDelimitedWriter::new(file)))
}
}
}
}
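/// A [`TableProvider`] for an unbounded stream source that reads from, and
/// appends to, a single file in place (for example a FIFO file)
///
/// A minimal construction sketch; the path and schema below are illustrative
/// only and assume a newline-delimited JSON file:
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow_schema::{DataType, Field, Schema};
///
/// let schema = Arc::new(Schema::new(vec![Field::new("value", DataType::Int64, false)]));
/// let config = StreamConfig::new_file(schema, "/tmp/events.json".into())
///     .with_encoding(StreamEncoding::Json)
///     .with_batch_size(256);
/// let table = StreamTable::new(Arc::new(config));
/// ```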
pub struct StreamTable(Arc<StreamConfig>);
impl StreamTable {
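    /// Create a new [`StreamTable`] for the given [`StreamConfig`]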
pub fn new(config: Arc<StreamConfig>) -> Self {
Self(config)
}
}
#[async_trait]
impl TableProvider for StreamTable {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.0.schema.clone()
}
fn constraints(&self) -> Option<&Constraints> {
Some(&self.0.constraints)
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
        // Note: `create_ordering` returns sort orderings (projected to the
        // requested columns when a projection is supplied), not a schema
        let projected_orderings = match projection {
            Some(p) => {
                let projected = self.0.schema.project(p)?;
                create_ordering(&projected, &self.0.order)?
            }
            None => create_ordering(self.0.schema.as_ref(), &self.0.order)?,
        };
        Ok(Arc::new(StreamingTableExec::try_new(
            self.0.schema.clone(),
            vec![Arc::new(StreamRead(self.0.clone())) as _],
            projection,
            projected_orderings,
            true, // an unbounded (infinite) source
            limit,
        )?))
}
async fn insert_into(
&self,
_state: &SessionState,
input: Arc<dyn ExecutionPlan>,
_overwrite: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
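        // If sort orders were specified, require the sink's input to be sorted
        // by the first (primary) ordering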
let ordering = match self.0.order.first() {
Some(x) => {
let schema = self.0.schema.as_ref();
let orders = create_ordering(schema, std::slice::from_ref(x))?;
let ordering = orders.into_iter().next().unwrap();
Some(ordering.into_iter().map(Into::into).collect())
}
None => None,
};
Ok(Arc::new(DataSinkExec::new(
input,
Arc::new(StreamWrite(self.0.clone())),
self.0.schema.clone(),
ordering,
)))
}
}
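/// A [`PartitionStream`] that reads record batches from the configured file on
/// a spawned blocking task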
struct StreamRead(Arc<StreamConfig>);
impl PartitionStream for StreamRead {
fn schema(&self) -> &SchemaRef {
&self.0.schema
}
fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let config = self.0.clone();
let schema = self.0.schema.clone();
let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2);
let tx = builder.tx();
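        // Read the file on a blocking task, forwarding each batch (or error)
        // to the stream until the receiver is dropped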
builder.spawn_blocking(move || {
let reader = config.reader()?;
for b in reader {
if tx.blocking_send(b.map_err(Into::into)).is_err() {
break;
}
}
Ok(())
});
builder.build()
}
}
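/// A [`DataSink`] that appends record batches to the configured file from a
/// spawned blocking task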
#[derive(Debug)]
struct StreamWrite(Arc<StreamConfig>);
impl DisplayAs for StreamWrite {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
f.debug_struct("StreamWrite")
.field("location", &self.0.location)
.field("batch_size", &self.0.batch_size)
.field("encoding", &self.0.encoding)
.field("header", &self.0.header)
.finish_non_exhaustive()
}
}
#[async_trait]
impl DataSink for StreamWrite {
fn as_any(&self) -> &dyn Any {
self
}
fn metrics(&self) -> Option<MetricsSet> {
None
}
async fn write_all(
&self,
mut data: SendableRecordBatchStream,
_context: &Arc<TaskContext>,
) -> Result<u64> {
let config = self.0.clone();
let (sender, mut receiver) = tokio::sync::mpsc::channel::<RecordBatch>(2);
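        // Append batches to the file on a blocking task, fed through a bounded
        // channel from the async input stream below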
let write_task = SpawnedTask::spawn_blocking(move || {
let mut count = 0_u64;
let mut writer = config.writer()?;
while let Some(batch) = receiver.blocking_recv() {
count += batch.num_rows() as u64;
writer.write(&batch)?;
}
Ok(count)
});
while let Some(b) = data.next().await.transpose()? {
if sender.send(b).await.is_err() {
break;
}
}
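        // Dropping the sender signals end-of-stream to the writer task; wait
        // for it to finish and return the number of rows written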
drop(sender);
write_task.join_unwind().await
}
}
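#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    // A minimal sketch exercising `StreamConfig::reader` against a small CSV
    // file written to the platform temp directory; the file name and schema
    // below are illustrative only.
    #[test]
    fn read_csv_with_stream_config() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let location = std::env::temp_dir().join("stream_config_reader_test.csv");
        std::fs::write(&location, "1,foo\n2,bar\n")?;

        let config = StreamConfig::new_file(schema, location.clone())
            .with_encoding(StreamEncoding::Csv)
            .with_header(false)
            .with_batch_size(8);

        // The reader yields `Result<RecordBatch, ArrowError>` items
        let mut rows = 0;
        for batch in config.reader()? {
            rows += batch?.num_rows();
        }
        assert_eq!(rows, 2);

        std::fs::remove_file(&location)?;
        Ok(())
    }
}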