use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use arrow_array::{RecordBatch, RecordBatchReader};
use datafusion::error::DataFusionError;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::{Stream, StreamExt, TryStreamExt};
use object_store::path::Path;
use uuid::Uuid;
use crate::error::Result;
use crate::Error;
use crate::{
datatypes::Schema,
format::Fragment,
io::{object_store::ObjectStoreParams, FileWriter, ObjectStore},
};
use super::DATA_DIR;
/// Selects how a write should treat the destination dataset.
///
/// NOTE(review): none of the variants are matched on in this part of the file, so the
/// per-variant notes below are inferred from the names only — confirm against the code
/// that dispatches on `WriteMode`.
#[derive(Debug, Clone, Copy)]
pub enum WriteMode {
/// Presumably: create a new dataset (this is the `WriteParams::default()` mode).
Create,
/// Presumably: add the written data to an existing dataset.
Append,
/// Presumably: replace the contents of an existing dataset.
Overwrite,
}
/// Tunable parameters for writing record batches into data files.
#[derive(Debug, Clone)]
pub struct WriteParams {
/// Row threshold per data file: `write_fragments` finishes the current file as soon as
/// the rows written to it reach at least this number.
pub max_rows_per_file: usize,
/// Number of rows per chunk handed to a single writer `write` call.
/// `write_fragments` clamps this value so it never exceeds `max_rows_per_file`.
pub max_rows_per_group: usize,
/// Requested write mode. Not read by `write_fragments`.
pub mode: WriteMode,
/// Optional object-store parameters. Not read by `write_fragments`
/// (presumably consumed where the object store is opened — confirm).
pub store_params: Option<ObjectStoreParams>,
}
impl Default for WriteParams {
    /// Default parameters: 1024 rows per group, 1024 * 1024 rows per file,
    /// `WriteMode::Create`, and no object-store overrides.
    fn default() -> Self {
        const ROWS_PER_GROUP: usize = 1024;
        const ROWS_PER_FILE: usize = 1024 * 1024;

        Self {
            max_rows_per_file: ROWS_PER_FILE,
            max_rows_per_group: ROWS_PER_GROUP,
            mode: WriteMode::Create,
            store_params: None,
        }
    }
}
pub fn reader_to_stream(
batches: Box<dyn RecordBatchReader + Send>,
) -> Result<(SendableRecordBatchStream, Schema)> {
let arrow_schema = batches.schema();
let mut schema: Schema = Schema::try_from(batches.schema().as_ref())?;
let mut peekable = batches.peekable();
if let Some(batch) = peekable.peek() {
if let Ok(b) = batch {
schema.set_dictionary(b)?;
} else {
return Err(Error::from(batch.as_ref().unwrap_err()));
}
}
schema.validate()?;
let stream = RecordBatchStreamAdapter::new(
arrow_schema,
futures::stream::iter(peekable).map_err(DataFusionError::from),
);
let stream = Box::pin(stream) as SendableRecordBatchStream;
Ok((stream, schema))
}
/// Writes every batch of `data` into one or more data files below `base_dir` and returns
/// one `Fragment` per file created.
///
/// The input is re-chunked into groups of `params.max_rows_per_group` rows (clamped to
/// `params.max_rows_per_file`). A new file is opened lazily when the first chunk for it
/// arrives, and the file is finished once at least `params.max_rows_per_file` rows have
/// been written to it.
///
/// # Errors
///
/// Propagates any error from the input stream, from opening a writer, or from writing
/// or finishing a file.
pub async fn write_fragments(
    object_store: Arc<ObjectStore>,
    base_dir: &Path,
    schema: &Schema,
    data: SendableRecordBatchStream,
    mut params: WriteParams,
) -> Result<Vec<Fragment>> {
    // A group can never be larger than the file that contains it.
    params.max_rows_per_group = params.max_rows_per_group.min(params.max_rows_per_file);

    let mut chunks = chunk_stream(data, params.max_rows_per_group);
    let generator = WriterGenerator::new(object_store, base_dir, schema);

    let mut fragments = Vec::new();
    let mut open_writer: Option<FileWriter> = None;
    let mut rows_in_file = 0;

    while let Some(chunk) = chunks.next().await {
        let chunk = chunk?;

        // Reuse the writer left open by the previous iteration, or start a new file.
        let mut active = match open_writer.take() {
            Some(existing) => existing,
            None => {
                let (created, fragment) = generator.new_writer().await?;
                fragments.push(fragment);
                created
            }
        };

        active.write(&chunk).await?;
        rows_in_file += chunk.iter().map(|batch| batch.num_rows()).sum::<usize>();

        if rows_in_file >= params.max_rows_per_file {
            // File is full: close it so the next chunk starts a fresh one.
            active.finish().await?;
            rows_in_file = 0;
        } else {
            open_writer = Some(active);
        }
    }

    // Close the trailing, partially filled file (if any).
    if let Some(mut last) = open_writer.take() {
        last.finish().await?;
    }
    Ok(fragments)
}
/// Factory that opens a new data-file writer, together with its `Fragment` record, on
/// demand (see `new_writer`).
struct WriterGenerator {
/// Store handed to every `FileWriter` this generator creates.
object_store: Arc<ObjectStore>,
/// Base directory; data files are placed under its `DATA_DIR` child.
base_dir: Path,
/// Schema passed to each writer and recorded on each fragment's file entry.
schema: Schema,
}
impl WriterGenerator {
    /// Builds a generator that owns clones of `base_dir` and `schema`.
    pub fn new(object_store: Arc<ObjectStore>, base_dir: &Path, schema: &Schema) -> Self {
        let base_dir = base_dir.clone();
        let schema = schema.clone();
        Self {
            object_store,
            base_dir,
            schema,
        }
    }

    /// Opens a writer for a new, uniquely named `<uuid>.lance` file under
    /// `<base_dir>/<DATA_DIR>/` and returns it with a fragment describing that file.
    ///
    /// The fragment is created with id `0` and its file entry stores the bare file name,
    /// not the full path.
    pub async fn new_writer(&self) -> Result<(FileWriter, Fragment)> {
        let file_name = format!("{}.lance", Uuid::new_v4());

        let fragment = {
            let mut pending = Fragment::new(0);
            pending.add_file(&file_name, &self.schema);
            pending
        };

        let location = self.base_dir.child(DATA_DIR).child(file_name);
        let writer =
            FileWriter::try_new(self.object_store.as_ref(), &location, self.schema.clone())
                .await?;
        Ok((writer, fragment))
    }
}
/// Re-chunks `stream` so that each yielded item is a list of batches (or batch slices)
/// whose row counts add up to `chunk_size`; only the final chunk may hold fewer rows.
///
/// Errors from the underlying stream are yielded as `Err` items; the stream ends when
/// the chunker reports no more data.
fn chunk_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
    let initial_state = BatchReaderChunker::new(stream, chunk_size);
    let chunks = futures::stream::unfold(initial_state, |mut state| async move {
        // Thread the chunker through as the unfold state; `None` terminates the stream.
        let produced = state.next().await;
        produced.map(|item| (item, state))
    });
    chunks.boxed()
}
/// Re-chunks an input batch stream into groups of batches totalling `output_size` rows.
struct BatchReaderChunker {
/// Upstream source of record batches.
inner: SendableRecordBatchStream,
/// Batches pulled from `inner` that have not been fully emitted yet.
buffered: VecDeque<RecordBatch>,
/// Target number of rows per emitted chunk.
output_size: usize,
/// Number of leading rows of the front batch in `buffered` that were already emitted
/// as part of an earlier chunk.
i: usize,
}
impl BatchReaderChunker {
    /// Creates a chunker over `inner` that emits chunks of `output_size` rows.
    pub fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
        Self {
            inner,
            buffered: VecDeque::new(),
            output_size,
            i: 0,
        }
    }

    /// Number of buffered rows that have not been emitted yet: the total rows of all
    /// buffered batches minus the already-consumed prefix (`i`) of the front batch.
    fn buffered_len(&self) -> usize {
        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
        buffer_total - self.i
    }

    /// Pulls batches from `inner` until at least `output_size` rows are buffered or the
    /// input is exhausted.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by `inner`.
    async fn fill_buffer(&mut self) -> Result<()> {
        // Count the buffered rows once and keep a running total. Re-summing the whole
        // deque on every iteration is quadratic in the number of buffered batches when
        // many small batches are needed to fill one chunk.
        let mut available = self.buffered_len();
        while available < self.output_size {
            match self.inner.next().await {
                Some(Ok(batch)) => {
                    available += batch.num_rows();
                    self.buffered.push_back(batch);
                }
                Some(Err(e)) => return Err(e.into()),
                None => break,
            }
        }
        Ok(())
    }

    /// Produces the next chunk: a list of batches/slices totalling `output_size` rows,
    /// or fewer if the input ran out.
    ///
    /// Returns `None` when no rows are left, and `Some(Err(_))` if reading the input
    /// failed.
    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
        if let Err(e) = self.fill_buffer().await {
            return Some(Err(e));
        }

        let mut batches = Vec::new();
        let mut rows_collected = 0;
        while rows_collected < self.output_size {
            let batch = match self.buffered.pop_front() {
                Some(batch) => batch,
                None => break,
            };

            let rows_remaining_in_batch = batch.num_rows() - self.i;
            let rows_to_take =
                std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);

            if rows_to_take == rows_remaining_in_batch {
                // The rest of this batch fits in the chunk. Emit it whole when nothing
                // was consumed from it before, otherwise emit only the unread tail.
                let batch = if self.i == 0 {
                    batch
                } else {
                    batch.slice(self.i, rows_to_take)
                };
                batches.push(batch);
                self.i = 0;
            } else {
                // Only part of the batch fits: emit a slice, remember how far we got,
                // and put the batch back at the front for the next chunk.
                batches.push(batch.slice(self.i, rows_to_take));
                self.i += rows_to_take;
                self.buffered.push_front(batch);
            }
            rows_collected += rows_to_take;
        }

        if batches.is_empty() {
            None
        } else {
            Some(Ok(batches))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow_schema::DataType;
    use arrow_schema::Schema as ArrowSchema;

    /// Builds a batch with one non-nullable Int32 column "a" holding `0..num_rows`.
    fn sequential_batch(num_rows: i32) -> (Arc<ArrowSchema>, RecordBatch) {
        let field = arrow::datatypes::Field::new("a", DataType::Int32, false);
        let schema = Arc::new(ArrowSchema::new(vec![field]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_iter(0..num_rows))],
        )
        .unwrap();
        (schema, batch)
    }

    /// Feeds `batches` through `chunk_stream` and collects every emitted chunk.
    async fn collect_chunks(
        schema: Arc<ArrowSchema>,
        batches: Vec<RecordBatch>,
        chunk_size: usize,
    ) -> Vec<Vec<RecordBatch>> {
        let source = RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(batches.into_iter().map(Ok::<_, DataFusionError>)),
        );
        chunk_stream(Box::pin(source), chunk_size)
            .try_collect()
            .await
            .unwrap()
    }

    /// Total number of rows across all batches of one chunk.
    fn row_count(chunk: &[RecordBatch]) -> usize {
        chunk.iter().map(|batch| batch.num_rows()).sum::<usize>()
    }

    /// Input batches larger than the chunk size must be split into slices.
    #[tokio::test]
    async fn test_chunking_large_batches() {
        let (schema, batch) = sequential_batch(28);
        let input: Vec<RecordBatch> =
            vec![batch.slice(0, 10), batch.slice(10, 10), batch.slice(20, 8)];

        let chunks = collect_chunks(schema, input, 3).await;

        // 28 rows in chunks of 3: nine full chunks plus one trailing row.
        assert_eq!(chunks.len(), 10);
        assert_eq!(chunks[0].len(), 1);
        let last_index = chunks.len() - 1;
        for (index, chunk) in chunks.iter().enumerate() {
            let expected_rows = if index < last_index { 3 } else { 1 };
            assert_eq!(row_count(chunk), expected_rows);
        }

        // The fourth chunk straddles the boundary between the first two input batches.
        assert_eq!(chunks[3].len(), 2);
        assert_eq!(chunks[3][0].num_rows(), 1);
        assert_eq!(chunks[3][1].num_rows(), 2);
    }

    /// Input batches smaller than the chunk size must be grouped together.
    #[tokio::test]
    async fn test_chunking_small_batches() {
        let (schema, batch) = sequential_batch(30);
        let input: Vec<RecordBatch> = (0..10).map(|i| batch.slice(i * 3, 3)).collect();

        let chunks = collect_chunks(schema, input, 10).await;

        assert_eq!(chunks.len(), 3);
        // Three whole input batches plus the first row of the fourth one.
        assert_eq!(chunks[0].len(), 4);
        let expected_first_chunk = [
            batch.slice(0, 3),
            batch.slice(3, 3),
            batch.slice(6, 3),
            batch.slice(9, 1),
        ];
        for (actual, expected) in chunks[0].iter().zip(expected_first_chunk.iter()) {
            assert_eq!(actual, expected);
        }

        for chunk in &chunks {
            assert_eq!(row_count(chunk), 10);
        }
    }
}