use crate::spill::get_record_batch_memory_size;
use arrow::compute::interleave;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;
use std::sync::Arc;
/// Read position of a single input stream within the batches buffered by
/// [`BatchBuilder`].
#[derive(Debug, Copy, Clone, Default)]
struct BatchCursor {
/// Index into `BatchBuilder::batches` of the batch most recently pushed
/// for this stream.
batch_idx: usize,
/// Row within that batch that the next `push_row` call will record.
row_idx: usize,
}
/// Accumulates rows selected from several input streams and assembles them
/// into a single [`RecordBatch`] via [`interleave`].
#[derive(Debug)]
pub struct BatchBuilder {
/// Schema used for every [`RecordBatch`] produced by this builder.
schema: SchemaRef,
/// Buffered input batches, each paired with the index of the stream it
/// was pushed for.
batches: Vec<(usize, RecordBatch)>,
/// Memory accounting for `batches`: grown when a batch is pushed and
/// shrunk when a batch is dropped after building output.
reservation: MemoryReservation,
/// One cursor per stream, indexed by stream index.
cursors: Vec<BatchCursor>,
/// Pending output rows as `(batch_idx, row_idx)` pairs, where `batch_idx`
/// indexes into `batches`.
indices: Vec<(usize, usize)>,
}
impl BatchBuilder {
    /// Creates an empty builder for `stream_count` input streams.
    ///
    /// `batch_size` is only used to pre-size the buffer of pending row
    /// indices. Buffered batches are accounted against `reservation`.
    pub fn new(
        schema: SchemaRef,
        stream_count: usize,
        batch_size: usize,
        reservation: MemoryReservation,
    ) -> Self {
        // Room for two buffered batches per stream before reallocating.
        let batches = Vec::with_capacity(stream_count * 2);
        let cursors = vec![BatchCursor::default(); stream_count];
        let indices = Vec::with_capacity(batch_size);
        Self {
            schema,
            batches,
            reservation,
            cursors,
            indices,
        }
    }

    /// Buffers `batch` as the current batch of stream `stream_idx` and
    /// resets that stream's cursor to the first row of the new batch.
    ///
    /// # Errors
    ///
    /// Returns an error, without buffering the batch, if the memory
    /// reservation cannot grow by the size of `batch`.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is not less than the stream count given to
    /// [`Self::new`].
    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
        let batch_bytes = get_record_batch_memory_size(&batch);
        self.reservation.try_grow(batch_bytes)?;

        // The new batch lands at the end of the buffer.
        let new_batch_idx = self.batches.len();
        self.batches.push((stream_idx, batch));
        self.cursors[stream_idx] = BatchCursor {
            batch_idx: new_batch_idx,
            row_idx: 0,
        };
        Ok(())
    }

    /// Appends the next unread row of stream `stream_idx` to the pending
    /// output and advances that stream's cursor by one row.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is not less than the stream count given to
    /// [`Self::new`].
    pub fn push_row(&mut self, stream_idx: usize) {
        let cursor = &mut self.cursors[stream_idx];
        let selected = (cursor.batch_idx, cursor.row_idx);
        cursor.row_idx += 1;
        self.indices.push(selected);
    }

    /// Number of rows pushed since the last call to
    /// [`Self::build_record_batch`].
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns `true` when no rows are pending.
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Schema of the batches produced by this builder.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Assembles all pending rows into a [`RecordBatch`], clears the
    /// pending rows, and drops every buffered batch that is no longer the
    /// current batch of its stream.
    ///
    /// Returns `Ok(None)` when no rows are pending.
    ///
    /// # Errors
    ///
    /// Returns an error if interleaving a column fails or if the output
    /// [`RecordBatch`] cannot be constructed.
    pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
        if self.is_empty() {
            return Ok(None);
        }

        // Build each output column by gathering the selected rows from the
        // matching column of every buffered batch.
        let column_count = self.schema.fields.len();
        let mut columns = Vec::with_capacity(column_count);
        for column_idx in 0..column_count {
            let sources: Vec<_> = self
                .batches
                .iter()
                .map(|(_, batch)| batch.column(column_idx).as_ref())
                .collect();
            columns.push(interleave(&sources, &self.indices)?);
        }

        self.indices.clear();

        // Keep only the batch each stream's cursor currently points at and
        // release the memory reserved for the rest.
        // NOTE(review): presumably callers push a new batch for a stream only
        // after consuming every row of the previous one, so the dropped
        // batches hold no unread rows - confirm against the caller.
        let Self {
            batches,
            cursors,
            reservation,
            ..
        } = &mut *self;
        let mut old_position = 0;
        let mut kept = 0;
        batches.retain(|(stream_idx, batch)| {
            let cursor = &mut cursors[*stream_idx];
            let is_current = cursor.batch_idx == old_position;
            old_position += 1;
            if is_current {
                // Re-point the cursor at the batch's position after compaction.
                cursor.batch_idx = kept;
                kept += 1;
            } else {
                reservation.shrink(get_record_batch_memory_size(batch));
            }
            is_current
        });

        let output = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
        Ok(Some(output))
    }
}