use std::io::Error;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::physical_plan::FileMeta;
use crate::error::Result;
use arrow_array::RecordBatch;
use datafusion_common::{exec_err, DataFusionError};
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::ready;
use futures::FutureExt;
use object_store::path::Path;
use object_store::{MultipartId, ObjectMeta, ObjectStore};
use tokio::io::AsyncWrite;
pub(crate) mod demux;
pub(crate) mod orchestration;
pub struct AsyncPutWriter {
object_meta: ObjectMeta,
store: Arc<dyn ObjectStore>,
current_buffer: Vec<u8>,
inner_state: AsyncPutState,
}
impl AsyncPutWriter {
pub fn new(object_meta: ObjectMeta, store: Arc<dyn ObjectStore>) -> Self {
Self {
object_meta,
store,
current_buffer: vec![],
inner_state: AsyncPutState::Buffer,
}
}
fn poll_shutdown_inner(
&mut self,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Error>> {
loop {
match &mut self.inner_state {
AsyncPutState::Buffer => {
let bytes = Bytes::from(mem::take(&mut self.current_buffer));
self.inner_state = AsyncPutState::Put { bytes }
}
AsyncPutState::Put { bytes } => {
return Poll::Ready(
ready!(self
.store
.put(&self.object_meta.location, bytes.clone())
.poll_unpin(cx))
.map_err(Error::from),
);
}
}
}
}
}
enum AsyncPutState {
Buffer,
Put { bytes: Bytes },
}
/// Writes only grow the in-memory buffer; the object store is contacted
/// solely from `poll_shutdown`.
impl AsyncWrite for AsyncPutWriter {
    /// Appends `buf` to the buffer and reports every byte as written.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, Error>> {
        let this = self.get_mut();
        this.current_buffer.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// No-op: buffered bytes are only sent when the writer is shut down.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    /// Uploads the buffered bytes via `poll_shutdown_inner`.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        this.poll_shutdown_inner(cx)
    }
}
/// Everything needed to cancel a multipart upload later: the store it was
/// started on, its upload id and its destination path.
pub(crate) struct MultiPart {
    /// Store the multipart upload was started on.
    store: Arc<dyn ObjectStore>,
    /// Upload identifier, passed back to `abort_multipart` when aborting.
    multipart_id: MultipartId,
    /// Destination path of the upload.
    location: Path,
}

impl MultiPart {
    /// Bundles the store, upload id and path of a multipart upload.
    pub fn new(
        store: Arc<dyn ObjectStore>,
        multipart_id: MultipartId,
        location: Path,
    ) -> Self {
        MultiPart {
            location,
            multipart_id,
            store,
        }
    }
}
/// How an `AbortableWrite` can be cancelled. The variant matches the way the
/// underlying writer was opened.
pub(crate) enum AbortMode {
    /// Aborting is a no-op that resolves to `Ok(())` immediately.
    Put,
    /// Aborting is not supported and yields an execution error.
    Append,
    /// Aborting calls `abort_multipart` with the stored upload details.
    MultiPart(MultiPart),
}
pub(crate) struct AbortableWrite<W: AsyncWrite + Unpin + Send> {
writer: W,
mode: AbortMode,
}
impl<W: AsyncWrite + Unpin + Send> AbortableWrite<W> {
pub(crate) fn new(writer: W, mode: AbortMode) -> Self {
Self { writer, mode }
}
pub(crate) fn abort_writer(&self) -> Result<BoxFuture<'static, Result<()>>> {
match &self.mode {
AbortMode::Put => Ok(async { Ok(()) }.boxed()),
AbortMode::Append => exec_err!("Cannot abort in append mode"),
AbortMode::MultiPart(MultiPart {
store,
multipart_id,
location,
}) => {
let location = location.clone();
let multipart_id = multipart_id.clone();
let store = store.clone();
Ok(Box::pin(async move {
store
.abort_multipart(&location, &multipart_id)
.await
.map_err(DataFusionError::ObjectStore)
}))
}
}
}
}
/// Transparent pass-through: every call is delegated to the wrapped writer.
impl<W: AsyncWrite + Unpin + Send> AsyncWrite for AbortableWrite<W> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.writer).poll_shutdown(cx)
    }
}
/// Strategy used to open the destination object for writing.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes and use them
/// as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FileWriterMode {
    /// Append to the object via `ObjectStore::append`; cannot be aborted.
    Append,
    /// Buffer all bytes in memory and upload them with one `put` on shutdown.
    Put,
    /// Stream the object as a multipart upload via `ObjectStore::put_multipart`.
    PutMultipart,
}
/// Encodes [`RecordBatch`]es into the byte representation of a file format.
#[async_trait]
pub trait BatchSerializer: Unpin + Send {
    /// Serializes `batch` and returns the encoded bytes.
    async fn serialize(&mut self, batch: RecordBatch) -> Result<Bytes>;

    /// Returns another serializer for use in parallel serialization.
    ///
    /// NOTE(review): implementors presumably return a serializer configured
    /// like `self`; confirm per file type.
    ///
    /// # Errors
    /// The default implementation always returns
    /// `DataFusionError::NotImplemented`.
    fn duplicate(&mut self) -> Result<Box<dyn BatchSerializer>> {
        let message = "Parallel serialization is not implemented for this file type";
        Err(DataFusionError::NotImplemented(message.to_string()))
    }
}
pub(crate) async fn create_writer(
writer_mode: FileWriterMode,
file_compression_type: FileCompressionType,
file_meta: FileMeta,
object_store: Arc<dyn ObjectStore>,
) -> Result<AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>> {
let object = &file_meta.object_meta;
match writer_mode {
FileWriterMode::Append => {
let writer = object_store
.append(&object.location)
.await
.map_err(DataFusionError::ObjectStore)?;
let writer = AbortableWrite::new(
file_compression_type.convert_async_writer(writer)?,
AbortMode::Append,
);
Ok(writer)
}
FileWriterMode::Put => {
let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store));
let writer = AbortableWrite::new(
file_compression_type.convert_async_writer(writer)?,
AbortMode::Put,
);
Ok(writer)
}
FileWriterMode::PutMultipart => {
let (multipart_id, writer) = object_store
.put_multipart(&object.location)
.await
.map_err(DataFusionError::ObjectStore)?;
Ok(AbortableWrite::new(
file_compression_type.convert_async_writer(writer)?,
AbortMode::MultiPart(MultiPart::new(
object_store,
multipart_id,
object.location.clone(),
)),
))
}
}
}