use std::sync::Arc;
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::listing::PartitionedFile;
use crate::datasource::physical_plan::FileSinkConfig;
use crate::error::Result;
use crate::physical_plan::SendableRecordBatchStream;
use arrow_array::RecordBatch;
use datafusion_common::DataFusionError;
use bytes::Bytes;
use datafusion_execution::TaskContext;
use futures::StreamExt;
use object_store::{ObjectMeta, ObjectStore};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::{JoinHandle, JoinSet};
use tokio::try_join;
use super::demux::start_demuxer_task;
use super::{create_writer, AbortableWrite, BatchSerializer, FileWriterMode};
/// Writer used throughout this file: an [`AbortableWrite`] wrapping a boxed,
/// sendable, unpinned [`AsyncWrite`] trait object.
type WriterType = AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>;
/// Boxed [`BatchSerializer`] trait object.
type SerializerType = Box<dyn BatchSerializer>;
/// Serializes every [`RecordBatch`] received on `data_rx` and writes the
/// resulting bytes to `writer`.
///
/// Each batch is serialized in its own spawned task, using a fresh duplicate of
/// `serializer`. The task handles are queued on a bounded channel (capacity 100)
/// in the order the batches arrived and are awaited in that same order, so the
/// bytes reach `writer` in arrival order even though serialization may overlap.
///
/// When `unbounded_input` is true the producer yields to the scheduler after
/// queueing each batch.
///
/// # Returns
///
/// * `Ok((writer, rows))` - the writer and the total number of rows written.
/// * `Err((writer, error))` - the writer is handed back together with the error
///   so the caller keeps ownership of it.
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    mut serializer: Box<dyn BatchSerializer>,
    mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
    unbounded_input: bool,
) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
    let (handle_tx, mut handle_rx) =
        mpsc::channel::<JoinHandle<Result<(usize, Bytes), DataFusionError>>>(100);

    // Producer: turns each incoming batch into a serialization task and queues
    // the task handle for the write loop below.
    let producer = tokio::spawn(async move {
        while let Some(batch) = data_rx.recv().await {
            let mut batch_serializer = match serializer.duplicate() {
                Ok(duplicate) => duplicate,
                Err(_) => {
                    return Err(DataFusionError::Internal(
                        "Unknown error writing to object store".into(),
                    ))
                }
            };
            let pending = tokio::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = batch_serializer.serialize(batch).await?;
                Ok((num_rows, bytes))
            });
            // The send only fails once the receiving half has been dropped,
            // i.e. the write loop has already returned.
            if handle_tx.send(pending).await.is_err() {
                return Err(DataFusionError::Internal(
                    "Unknown error writing to object store".into(),
                ));
            }
            if unbounded_input {
                tokio::task::yield_now().await;
            }
        }
        Ok(())
    });

    // Consumer: await the serialization tasks in queue order and write their
    // output.
    let mut rows_written = 0;
    while let Some(pending) = handle_rx.recv().await {
        let (batch_rows, payload) = match pending.await {
            Ok(Ok(serialized)) => serialized,
            Ok(Err(e)) => return Err((writer, e)),
            Err(e) => {
                return Err((
                    writer,
                    DataFusionError::Execution(format!(
                        "Serialization task panicked or was cancelled: {e}"
                    )),
                ))
            }
        };
        if let Err(e) = writer.write_all(&payload).await {
            return Err((
                writer,
                DataFusionError::Execution(format!(
                    "Error writing to object store: {e}"
                )),
            ));
        }
        rows_written += batch_rows;
    }

    // The handle channel is closed, so the producer has finished; surface its
    // outcome.
    match producer.await {
        Ok(Ok(())) => Ok((writer, rows_written as u64)),
        Ok(Err(e)) => Err((writer, e)),
        Err(_) => Err((
            writer,
            DataFusionError::Internal("Unknown error writing to object store".into()),
        )),
    }
}
/// The three pieces handed over for one output file: the channel delivering its
/// [`RecordBatch`]es, the serializer for those batches, and the writer that
/// receives the serialized bytes.
type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);
pub(crate) async fn stateless_serialize_and_write_files(
mut rx: Receiver<FileWriteBundle>,
tx: tokio::sync::oneshot::Sender<u64>,
unbounded_input: bool,
) -> Result<()> {
let mut row_count = 0;
let mut any_errors = false;
let mut triggering_error = None;
let mut any_abort_errors = false;
let mut join_set = JoinSet::new();
while let Some((data_rx, serializer, writer)) = rx.recv().await {
join_set.spawn(async move {
serialize_rb_stream_to_object_store(
data_rx,
serializer,
writer,
unbounded_input,
)
.await
});
}
let mut finished_writers = Vec::new();
while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => match res {
Ok((writer, cnt)) => {
finished_writers.push(writer);
row_count += cnt;
}
Err((writer, e)) => {
finished_writers.push(writer);
any_errors = true;
triggering_error = Some(e);
}
},
Err(e) => {
any_errors = true;
any_abort_errors = true;
triggering_error = Some(DataFusionError::Internal(format!(
"Unexpected join error while serializing file {e}"
)));
}
}
}
for mut writer in finished_writers.into_iter() {
match any_errors {
true => {
let abort_result = writer.abort_writer();
if abort_result.is_err() {
any_abort_errors = true;
}
}
false => {
writer.shutdown()
.await
.map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?;
}
}
}
if any_errors {
match any_abort_errors{
true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())),
false => match triggering_error {
Some(e) => return Err(e),
None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into()))
}
}
}
tx.send(row_count).map_err(|_| {
DataFusionError::Internal(
"Error encountered while sending row count back to file sink!".into(),
)
})?;
Ok(())
}
/// Writes `data` to one or more new files using multipart puts and returns the
/// total number of rows written.
///
/// The input stream is handed to the demuxer task, which yields an
/// `(output location, record batch channel)` pair per output file. For each
/// pair a serializer (from `get_serializer`) and a `PutMultipart` writer are
/// created and forwarded to a write coordinator task running
/// [`stateless_serialize_and_write_files`].
///
/// # Errors
///
/// Returns an error if the object store cannot be resolved, if `config` has no
/// table path, if a writer cannot be created, or if the demuxer or the write
/// coordinator report an error.
///
/// # Panics
///
/// Resumes the panic of the coordinator or demuxer task if either panicked.
pub(crate) async fn stateless_multipart_put(
    data: SendableRecordBatchStream,
    context: &Arc<TaskContext>,
    file_extension: String,
    get_serializer: Box<dyn Fn() -> Box<dyn BatchSerializer> + Send>,
    config: &FileSinkConfig,
    compression: FileCompressionType,
) -> Result<u64> {
    let object_store = context
        .runtime_env()
        .object_store(&config.object_store_url)?;

    let single_file_output = config.single_file_output;
    // Checked access: a config without table paths becomes an error instead of
    // an index-out-of-bounds panic, and is detected before any task is spawned.
    let base_output_path = config.table_paths.first().ok_or_else(|| {
        DataFusionError::Internal(
            "FileSinkConfig does not contain any table path to write to!".into(),
        )
    })?;
    let unbounded_input = config.unbounded_input;
    let part_cols = if !config.table_partition_cols.is_empty() {
        Some(config.table_partition_cols.clone())
    } else {
        None
    };

    let (demux_task, mut file_stream_rx) = start_demuxer_task(
        data,
        context,
        part_cols,
        base_output_path.clone(),
        file_extension,
        single_file_output,
    );

    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;

    // A tokio bounded channel panics when created with capacity 0, which is
    // what `rb_buffer_size / 2` evaluates to for a configured value of 0 or 1.
    let bundle_channel_capacity = std::cmp::max(rb_buffer_size / 2, 1);
    let (tx_file_bundle, rx_file_bundle) =
        tokio::sync::mpsc::channel(bundle_channel_capacity);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinater_task = tokio::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input)
            .await
    });

    // Pair every demuxed output file with a serializer and a writer, and hand
    // the bundle to the coordinator.
    while let Some((output_location, rb_stream)) = file_stream_rx.recv().await {
        let serializer = get_serializer();
        let object_meta = ObjectMeta {
            location: output_location,
            last_modified: chrono::offset::Utc::now(),
            size: 0,
            e_tag: None,
        };
        let writer = create_writer(
            FileWriterMode::PutMultipart,
            compression,
            object_meta.into(),
            object_store.clone(),
        )
        .await?;

        tx_file_bundle
            .send((rb_stream, serializer, writer))
            .await
            .map_err(|_| {
                DataFusionError::Internal(
                    "Writer receive file bundle channel closed unexpectedly!".into(),
                )
            })?;
    }

    // Dropping the sender closes the bundle channel, which lets the coordinator
    // leave its receive loop and finalize the writers.
    drop(tx_file_bundle);

    // Join both tasks before reading the row count so that their errors are
    // reported in preference to a closed row-count channel.
    match try_join!(write_coordinater_task, demux_task) {
        Ok((r1, r2)) => {
            r1?;
            r2?;
        }
        Err(e) => {
            if e.is_panic() {
                std::panic::resume_unwind(e.into_panic());
            } else {
                // Neither task is aborted by this function.
                unreachable!();
            }
        }
    }

    let total_count = rx_row_cnt.await.map_err(|_| {
        DataFusionError::Internal(
            "Did not receieve row count from write coordinater".into(),
        )
    })?;

    Ok(total_count)
}
/// Appends `data` to the existing files in `file_groups` and returns the total
/// number of rows written.
///
/// An `Append` writer and a serializer (from `get_serializer`, which receives
/// the file's current size) are created for every file. Incoming record batches
/// are then distributed round-robin across the files, while a write coordinator
/// task running [`stateless_serialize_and_write_files`] serializes and writes
/// them.
///
/// # Errors
///
/// Returns an error if `file_groups` is empty, if a writer cannot be created,
/// if the input stream yields an error, or if the write coordinator reports an
/// error.
///
/// # Panics
///
/// Resumes the panic of the write coordinator task if it panicked.
pub(crate) async fn stateless_append_all(
    mut data: SendableRecordBatchStream,
    context: &Arc<TaskContext>,
    object_store: Arc<dyn ObjectStore>,
    file_groups: &Vec<PartitionedFile>,
    unbounded_input: bool,
    compression: FileCompressionType,
    get_serializer: Box<dyn Fn(usize) -> Box<dyn BatchSerializer> + Send>,
) -> Result<u64> {
    // Without this guard an empty list would create a zero-capacity channel
    // (a panic in tokio) and the round-robin below would index an empty vector.
    if file_groups.is_empty() {
        return Err(DataFusionError::Internal(
            "Cannot append record batches: no files were provided!".into(),
        ));
    }

    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;
    // A tokio bounded channel panics when created with capacity 0, which is
    // what `rb_buffer_size / 2` evaluates to for a configured value of 0 or 1.
    let batch_channel_capacity = std::cmp::max(rb_buffer_size / 2, 1);

    // Sized to hold one bundle per file, so every send below completes before
    // the coordinator is started.
    let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(file_groups.len());
    let mut send_channels = Vec::with_capacity(file_groups.len());
    for file_group in file_groups {
        let serializer = get_serializer(file_group.object_meta.size);

        let writer = create_writer(
            FileWriterMode::Append,
            compression,
            file_group.object_meta.clone().into(),
            object_store.clone(),
        )
        .await?;

        let (tx, rx) = tokio::sync::mpsc::channel(batch_channel_capacity);
        send_channels.push(tx);
        tx_file_bundle
            .send((rx, serializer, writer))
            .await
            .map_err(|_| {
                DataFusionError::Internal(
                    "Writer receive file bundle channel closed unexpectedly!".into(),
                )
            })?;
    }

    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinater_task = tokio::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input)
            .await
    });

    // Distribute the incoming batches round-robin across the per-file channels.
    let mut next_file_idx = 0;
    while let Some(rb) = data.next().await.transpose()? {
        send_channels[next_file_idx].send(rb).await.map_err(|_| {
            DataFusionError::Internal(
                "Recordbatch file append stream closed unexpectedly!".into(),
            )
        })?;
        next_file_idx = (next_file_idx + 1) % send_channels.len();
        if unbounded_input {
            tokio::task::yield_now().await;
        }
    }

    // Closing the bundle channel and every batch channel signals the
    // coordinator that no more input will arrive.
    drop(tx_file_bundle);
    drop(send_channels);

    // Join the coordinator BEFORE awaiting the row count. When the coordinator
    // fails it drops the row-count sender without sending, so awaiting the
    // count first would replace the real error with a generic
    // "did not receive row count" one.
    match write_coordinater_task.await {
        Ok(coordinator_result) => coordinator_result?,
        Err(e) => {
            if e.is_panic() {
                std::panic::resume_unwind(e.into_panic());
            } else {
                // The coordinator task is never aborted by this function.
                unreachable!();
            }
        }
    }

    let total_count = rx_row_cnt.await.map_err(|_| {
        DataFusionError::Internal(
            "Did not receieve row count from write coordinater".into(),
        )
    })?;

    Ok(total_count)
}