use std::io::Error;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::datasource::physical_plan::FileMeta;
use crate::error::Result;
use crate::physical_plan::SendableRecordBatchStream;
use arrow_array::RecordBatch;
use datafusion_common::{exec_err, internal_err, DataFusionError, FileCompressionType};
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::FutureExt;
use futures::{ready, StreamExt};
use object_store::path::Path;
use object_store::{MultipartId, ObjectMeta, ObjectStore};
use tokio::io::{AsyncWrite, AsyncWriteExt};
/// An [`AsyncWrite`] implementation that accumulates every written byte in an
/// in-memory buffer and hands the whole buffer to `ObjectStore::put` when the
/// writer is shut down.
pub struct AsyncPutWriter {
/// Metadata of the target object; its `location` is the destination of the put.
object_meta: ObjectMeta,
/// Object store that receives the buffered bytes on shutdown.
store: Arc<dyn ObjectStore>,
/// Bytes accepted by `poll_write` so far.
current_buffer: Vec<u8>,
/// Progress of the shutdown state machine.
inner_state: AsyncPutState,
}
impl AsyncPutWriter {
pub fn new(object_meta: ObjectMeta, store: Arc<dyn ObjectStore>) -> Self {
Self {
object_meta,
store,
current_buffer: vec![],
inner_state: AsyncPutState::Buffer,
}
}
fn poll_shutdown_inner(
&mut self,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Error>> {
loop {
match &mut self.inner_state {
AsyncPutState::Buffer => {
let bytes = Bytes::from(mem::take(&mut self.current_buffer));
self.inner_state = AsyncPutState::Put { bytes }
}
AsyncPutState::Put { bytes } => {
return Poll::Ready(
ready!(self
.store
.put(&self.object_meta.location, bytes.clone())
.poll_unpin(cx))
.map_err(Error::from),
);
}
}
}
}
}
enum AsyncPutState {
Buffer,
Put { bytes: Bytes },
}
impl AsyncWrite for AsyncPutWriter {
    /// Appends `buf` to the in-memory buffer; always completes immediately and
    /// reports the whole slice as written.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, Error>> {
        let written = buf.len();
        let this = self.get_mut();
        this.current_buffer.extend_from_slice(buf);
        Poll::Ready(Ok(written))
    }

    /// Flushing does nothing: data only leaves the buffer during shutdown.
    fn poll_flush(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let flushed: std::result::Result<(), Error> = Ok(());
        Poll::Ready(flushed)
    }

    /// Delegates to the inherent shutdown routine, which performs the upload.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        this.poll_shutdown_inner(cx)
    }
}
/// Identifies a multipart upload: the store it was started on, the upload id
/// and the object location. `AbortableWrite::abort_writer` passes these to
/// `ObjectStore::abort_multipart`.
pub(crate) struct MultiPart {
/// Store on which the multipart upload was started.
store: Arc<dyn ObjectStore>,
/// Id of the multipart upload.
multipart_id: MultipartId,
/// Location of the object being uploaded.
location: Path,
}
impl MultiPart {
pub fn new(
store: Arc<dyn ObjectStore>,
multipart_id: MultipartId,
location: Path,
) -> Self {
Self {
store,
multipart_id,
location,
}
}
}
/// Selects what `AbortableWrite::abort_writer` does for a given writer.
pub(crate) enum AbortMode {
/// Aborting is a no-op that resolves to `Ok(())`.
Put,
/// Aborting is not possible; `abort_writer` returns an execution error.
Append,
/// Aborting calls `abort_multipart` on the store for the contained upload.
MultiPart(MultiPart),
}
/// Pairs an [`AsyncWrite`] with the [`AbortMode`] describing how the write can
/// be aborted. All `AsyncWrite` calls are forwarded to the inner writer.
pub(crate) struct AbortableWrite<W: AsyncWrite + Unpin + Send> {
/// The wrapped writer that receives all data.
writer: W,
/// How (and whether) the write can be aborted.
mode: AbortMode,
}
impl<W: AsyncWrite + Unpin + Send> AbortableWrite<W> {
pub(crate) fn new(writer: W, mode: AbortMode) -> Self {
Self { writer, mode }
}
pub(crate) fn abort_writer(&self) -> Result<BoxFuture<'static, Result<()>>> {
match &self.mode {
AbortMode::Put => Ok(async { Ok(()) }.boxed()),
AbortMode::Append => exec_err!("Cannot abort in append mode"),
AbortMode::MultiPart(MultiPart {
store,
multipart_id,
location,
}) => {
let location = location.clone();
let multipart_id = multipart_id.clone();
let store = store.clone();
Ok(Box::pin(async move {
store
.abort_multipart(&location, &multipart_id)
.await
.map_err(DataFusionError::ObjectStore)
}))
}
}
}
}
impl<W: AsyncWrite + Unpin + Send> AsyncWrite for AbortableWrite<W> {
    /// Forwards the write to the wrapped writer.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped writer.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped writer.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_shutdown(cx)
    }
}
/// Selects how `create_writer` obtains the underlying writer for a file.
#[derive(Debug, Clone, Copy)]
pub enum FileWriterMode {
/// Obtain the writer from `ObjectStore::append`.
Append,
/// Buffer the data in an `AsyncPutWriter`, which uploads it with
/// `ObjectStore::put` on shutdown.
Put,
/// Obtain the writer from `ObjectStore::put_multipart`.
PutMultipart,
}
/// Converts [`RecordBatch`]es into the bytes that are written to an output file.
#[async_trait]
pub trait BatchSerializer: Unpin + Send {
/// Serializes `batch` into bytes, or returns an error.
///
/// Takes `&mut self`, so implementations may keep state between calls.
async fn serialize(&mut self, batch: RecordBatch) -> Result<Bytes>;
}
/// Passes an `Ok` result straight through. On `Err`, first tries to abort
/// every writer in `writers` and then returns the original error.
///
/// Aborting is best effort: writers that cannot produce an abort future are
/// skipped and failures of the abort futures themselves are ignored.
async fn check_for_errors<T, W: AsyncWrite + Unpin + Send>(
    result: Result<T>,
    writers: &mut [AbortableWrite<W>],
) -> Result<T> {
    let failure = match result {
        Ok(value) => return Ok(value),
        Err(failure) => failure,
    };

    for writer in writers.iter_mut() {
        if let Ok(abort) = writer.abort_writer() {
            // The original failure is what gets reported; abort errors are dropped.
            let _ = abort.await;
        }
    }

    Err(failure)
}
/// Creates the writer for `file_meta` on `object_store` according to
/// `writer_mode`, wraps it with the requested compression and pairs it with
/// the matching [`AbortMode`].
///
/// # Errors
///
/// Returns an error if the object store fails to open the append or multipart
/// writer, or if the compression wrapper cannot be created.
pub(crate) async fn create_writer(
    writer_mode: FileWriterMode,
    file_compression_type: FileCompressionType,
    file_meta: FileMeta,
    object_store: Arc<dyn ObjectStore>,
) -> Result<AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>> {
    let meta = &file_meta.object_meta;

    // Each mode yields an uncompressed writer plus the way to abort it.
    let (raw_writer, abort_mode) = match writer_mode {
        FileWriterMode::Append => {
            let appender: Box<dyn AsyncWrite + Send + Unpin> = object_store
                .append(&meta.location)
                .await
                .map_err(DataFusionError::ObjectStore)?;
            (appender, AbortMode::Append)
        }
        FileWriterMode::Put => {
            let buffered: Box<dyn AsyncWrite + Send + Unpin> =
                Box::new(AsyncPutWriter::new(meta.clone(), object_store));
            (buffered, AbortMode::Put)
        }
        FileWriterMode::PutMultipart => {
            let (multipart_id, uploader) = object_store
                .put_multipart(&meta.location)
                .await
                .map_err(DataFusionError::ObjectStore)?;
            let upload = MultiPart::new(object_store, multipart_id, meta.location.clone());
            (uploader, AbortMode::MultiPart(upload))
        }
    };

    let compressed = file_compression_type.convert_async_writer(raw_writer)?;
    Ok(AbortableWrite::new(compressed, abort_mode))
}
pub(crate) async fn stateless_serialize_and_write_files(
mut data: Vec<SendableRecordBatchStream>,
mut serializers: Vec<Box<dyn BatchSerializer>>,
mut writers: Vec<AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>>,
single_file_output: bool,
) -> Result<u64> {
if single_file_output && (serializers.len() != 1 || writers.len() != 1) {
return internal_err!("single_file_output is true, but got more than 1 writer!");
}
let num_partitions = data.len();
if !single_file_output && (num_partitions != writers.len()) {
return internal_err!("single_file_ouput is false, but did not get 1 writer for each output partition!");
}
let mut row_count = 0;
let err_converter =
|_| DataFusionError::Internal("Unexpected FileSink Error".to_string());
for (part_idx, data_stream) in data.iter_mut().enumerate().take(num_partitions) {
let idx = match single_file_output {
false => part_idx,
true => 0,
};
while let Some(maybe_batch) = data_stream.next().await {
let serializer = &mut serializers[idx];
let batch = check_for_errors(maybe_batch, &mut writers).await?;
row_count += batch.num_rows();
let bytes =
check_for_errors(serializer.serialize(batch).await, &mut writers).await?;
let writer = &mut writers[idx];
check_for_errors(
writer.write_all(&bytes).await.map_err(err_converter),
&mut writers,
)
.await?;
}
}
let n_writers = writers.len();
for idx in 0..n_writers {
check_for_errors(
writers[idx].shutdown().await.map_err(err_converter),
&mut writers,
)
.await?;
}
Ok(row_count as u64)
}