use ahash::AHashMap;
use std::{pin::Pin, task::Poll};
use futures::{future::BoxFuture, AsyncWrite, AsyncWriteExt, FutureExt, Sink, TryFutureExt};
use parquet2::metadata::KeyValue;
use parquet2::write::FileStreamer;
use parquet2::write::WriteOptions as ParquetWriteOptions;
use crate::{array::Array, chunk::Chunk, datatypes::Schema, error::Error};
use super::file::add_arrow_schema;
use super::{Encoding, SchemaDescriptor, WriteOptions};
/// Sink that writes array [`Chunk`]s as row groups of a Parquet file to an
/// async writer.
///
/// Entries placed in [`FileSink::metadata`] before closing are written into
/// the file footer's key-value metadata when the sink is closed.
pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> {
    // Underlying parquet streamer. `None` while an in-flight `task` owns it,
    // or permanently after the file has been finalized in `poll_close`.
    writer: Option<FileStreamer<W>>,
    // In-flight async work. Resolves to `Ok(Some(writer))` after a row-group
    // write (writer is handed back), or `Ok(None)` after finalization.
    task: Option<BoxFuture<'a, Result<Option<FileStreamer<W>>, Error>>>,
    // Arrow-level write options (also used to derive parquet write options).
    options: WriteOptions,
    // Per-field encodings; length must equal `schema.fields.len()`.
    encodings: Vec<Vec<Encoding>>,
    // Arrow schema of the chunks accepted by this sink.
    schema: Schema,
    // Parquet schema derived from `schema` at construction time.
    parquet_schema: SchemaDescriptor,
    /// Key-value metadata that will be written to the file footer on close.
    pub metadata: AHashMap<String, Option<String>>,
}
impl<'a, W> FileSink<'a, W>
where
    W: AsyncWrite + Send + Unpin + 'a,
{
    /// Creates a new [`FileSink`] writing to `writer`.
    ///
    /// Derives the parquet schema from `schema` and eagerly constructs the
    /// underlying [`FileStreamer`]; nothing is written until chunks are sent.
    ///
    /// # Errors
    /// Returns [`Error::InvalidArgumentError`] when `encodings` does not
    /// contain exactly one entry per field of `schema`.
    pub fn try_new(
        writer: W,
        schema: Schema,
        encodings: Vec<Vec<Encoding>>,
        options: WriteOptions,
    ) -> Result<Self, Error> {
        if schema.fields.len() != encodings.len() {
            return Err(Error::InvalidArgumentError(
                "The number of encodings must equal the number of fields".to_string(),
            ));
        }

        let parquet_schema = crate::io::parquet::write::to_parquet_schema(&schema)?;
        let parquet_options = ParquetWriteOptions {
            version: options.version,
            write_statistics: options.write_statistics,
        };
        let streamer = FileStreamer::new(
            writer,
            parquet_schema.clone(),
            parquet_options,
            Some("Arrow2 - Native Rust implementation of Arrow".to_string()),
        );

        Ok(Self {
            writer: Some(streamer),
            task: None,
            options,
            encodings,
            schema,
            parquet_schema,
            metadata: AHashMap::default(),
        })
    }

    /// The Arrow [`Schema`] of the chunks this sink accepts.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }

    /// The parquet [`SchemaDescriptor`] derived from the Arrow schema.
    pub fn parquet_schema(&self) -> &SchemaDescriptor {
        &self.parquet_schema
    }

    /// The [`WriteOptions`] this sink was created with.
    pub fn options(&self) -> &WriteOptions {
        &self.options
    }

    /// Drives any in-flight task to completion, restoring ownership of the
    /// writer on success. Ready with `Ok(())` immediately when idle.
    fn poll_complete(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Error>> {
        let task = match self.task.as_mut() {
            Some(task) => task,
            None => return Poll::Ready(Ok(())),
        };
        let outcome = futures::ready!(task.poll_unpin(cx));
        // The task has finished either way; clear it before handling the result.
        self.task = None;
        Poll::Ready(match outcome {
            Ok(writer) => {
                self.writer = writer;
                Ok(())
            }
            Err(error) => Err(error),
        })
    }
}
impl<'a, W> Sink<Chunk<Box<dyn Array>>> for FileSink<'a, W>
where
    W: AsyncWrite + Send + Unpin + 'a,
{
    type Error = Error;
    /// Queues `item` for writing as a row group.
    ///
    /// Per the [`Sink`] contract, callers must have seen `poll_ready` return
    /// `Ready(Ok(()))` first — that guarantees `self.task` is `None` and the
    /// writer has been handed back, so the `take()` below succeeds.
    fn start_send(self: Pin<&mut Self>, item: Chunk<Box<dyn Array>>) -> Result<(), Self::Error> {
        if self.schema.fields.len() != item.arrays().len() {
            return Err(Error::InvalidArgumentError(
                "The number of arrays in the chunk must equal the number of fields in the schema"
                    .to_string(),
            ));
        }
        let this = self.get_mut();
        if let Some(mut writer) = this.writer.take() {
            let rows = crate::io::parquet::write::row_group_iter(
                item,
                this.encodings.clone(),
                this.parquet_schema.fields().to_vec(),
                this.options,
            );
            // Move the writer into the task; `poll_complete` puts it back
            // once the row group has been written.
            this.task = Some(Box::pin(async move {
                writer.write(rows).await?;
                Ok(Some(writer))
            }));
            Ok(())
        } else {
            // Writer is absent: either the sink was already closed, or the
            // caller violated the Sink contract by skipping `poll_ready`.
            Err(Error::Io(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "writer closed".to_string(),
            )))
        }
    }
    /// Ready when no write/close task is in flight.
    fn poll_ready(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.get_mut().poll_complete(cx)
    }
    /// Flushing only waits for the pending row-group write, if any.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.get_mut().poll_complete(cx)
    }
    /// Finishes any pending write, then writes the footer (including the
    /// Arrow schema and any user `metadata`) and closes the underlying writer.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        match futures::ready!(this.poll_complete(cx)) {
            Ok(()) => {
                // `take()` permanently removes the writer: the finalization
                // task below resolves to `None`, so it is never restored.
                let writer = this.writer.take();
                if let Some(mut writer) = writer {
                    // Drain user metadata; empty map means "no extra entries".
                    let meta = std::mem::take(&mut this.metadata);
                    let metadata = if meta.is_empty() {
                        None
                    } else {
                        Some(
                            meta.into_iter()
                                .map(|(k, v)| KeyValue::new(k, v))
                                .collect::<Vec<_>>(),
                        )
                    };
                    // Embed the serialized Arrow schema alongside user metadata.
                    let kv_meta = add_arrow_schema(&this.schema, metadata);
                    this.task = Some(Box::pin(async move {
                        writer.end(kv_meta).map_err(Error::from).await?;
                        writer.into_inner().close().map_err(Error::from).await?;
                        Ok(None)
                    }));
                    // Poll the freshly-created finalization task at least once
                    // so a waker is registered before returning Pending.
                    this.poll_complete(cx)
                } else {
                    // Writer already gone: the file was closed previously.
                    Poll::Ready(Ok(()))
                }
            }
            Err(error) => Poll::Ready(Err(error)),
        }
    }
}