use std::any::Any;
use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::datasource::sink::DataSink;
use datafusion::error::Result as DFResult;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::DisplayAs;
use futures::StreamExt;
use paimon::table::Table;
use crate::error::to_datafusion_error;
/// DataFusion [`DataSink`] that writes incoming record batches into a Paimon [`Table`].
///
/// Depending on `overwrite`, the collected commit messages are handed either to
/// Paimon's regular `commit` or to its `overwrite` commit.
#[derive(Debug)]
pub struct PaimonDataSink {
    /// Destination table; `write_all` creates its write builder from it.
    table: Table,
    /// Arrow schema reported to DataFusion through `DataSink::schema`.
    /// NOTE(review): `write_all` does not check incoming batches against it.
    schema: ArrowSchemaRef,
    /// Chooses the commit path in `write_all`: `overwrite(..)` when `true`,
    /// `commit(..)` when `false`.
    overwrite: bool,
}

impl PaimonDataSink {
    /// Creates a sink for `table` that advertises `schema` to DataFusion.
    ///
    /// `overwrite` selects whether the final commit goes through Paimon's
    /// overwrite path (`true`) or its regular commit path (`false`).
    pub fn new(table: Table, schema: ArrowSchemaRef, overwrite: bool) -> Self {
        PaimonDataSink {
            table,
            schema,
            overwrite,
        }
    }
}
impl DisplayAs for PaimonDataSink {
    /// Writes `PaimonDataSink: table=<identifier>` into `f`.
    ///
    /// The display format type is ignored, so every format produces the same
    /// one-line text. The `overwrite` flag is not part of the output.
    fn fmt_as(
        &self,
        _format_type: datafusion::physical_plan::DisplayFormatType,
        f: &mut fmt::Formatter<'_>,
    ) -> fmt::Result {
        let identifier = self.table.identifier();
        write!(f, "PaimonDataSink: table={identifier}")
    }
}
#[async_trait]
impl DataSink for PaimonDataSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the Arrow schema this sink was constructed with.
    fn schema(&self) -> &ArrowSchemaRef {
        &self.schema
    }

    /// Streams every record batch from `data` into one Paimon table write, then commits it.
    ///
    /// Steps:
    /// 1. create a write builder and open a single table write from it;
    /// 2. drain `data`, forwarding each batch to the writer and summing row counts;
    /// 3. call `prepare_commit` to obtain the commit messages;
    /// 4. pass those messages to `overwrite` when the sink was built with
    ///    `overwrite = true`, otherwise to `commit`.
    ///
    /// Returns the total number of rows read from `data`.
    ///
    /// # Errors
    /// Propagates the first error yielded by `data`, and any Paimon error
    /// (converted with `to_datafusion_error`). Either kind of error returns early,
    /// before any later step runs.
    async fn write_all(
        &self,
        mut data: SendableRecordBatchStream,
        _context: &Arc<TaskContext>,
    ) -> DFResult<u64> {
        let write_builder = self.table.new_write_builder();
        let mut table_write = write_builder.new_write().map_err(to_datafusion_error)?;

        let mut total_rows: u64 = 0;
        // `transpose` turns `Option<Result<_>>` into `Result<Option<_>>`: a stream
        // error is propagated by `?`, and `None` (stream exhausted) ends the loop.
        // NOTE(review): an early return here leaves whatever `table_write` has already
        // produced without an explicit abort — confirm whether paimon requires one.
        while let Some(record_batch) = data.next().await.transpose()? {
            total_rows += record_batch.num_rows() as u64;
            table_write
                .write_arrow_batch(&record_batch)
                .await
                .map_err(to_datafusion_error)?;
        }

        let commit_messages = table_write
            .prepare_commit()
            .await
            .map_err(to_datafusion_error)?;

        let table_commit = write_builder.new_commit();
        if self.overwrite {
            table_commit
                .overwrite(commit_messages)
                .await
                .map_err(to_datafusion_error)?;
        } else {
            table_commit
                .commit(commit_messages)
                .await
                .map_err(to_datafusion_error)?;
        }

        Ok(total_rows)
    }
}