use std::ops::Range;
use std::sync::{Arc, Mutex};
use arrow::compute::concat_batches;
use arrow::datatypes::UInt64Type;
use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use futures::{future::try_join_all, prelude::*};
use lance_arrow::stream::rechunk_stream_by_size;
use lance_arrow::{RecordBatchExt, SchemaExt};
use lance_core::{
Error, Result,
cache::LanceCache,
utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
};
use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
use lance_encoding::version::LanceFileVersion;
use lance_file::reader::{FileReader, FileReaderOptions};
use lance_file::writer::{FileWriter, FileWriterOptions};
use lance_io::{
ReadBatchParams,
object_store::ObjectStore,
scheduler::{ScanScheduler, SchedulerConfig},
stream::{RecordBatchStream, RecordBatchStreamAdapter},
utils::CachedFileSize,
};
use object_store::path::Path;
use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
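/// Reads back the output of a [`Shuffler`].
///
/// `read_partition` yields `None` for out-of-range or empty partitions, and
/// `total_loss` reports the summed loss carried in batch metadata under
/// `LOSS_METADATA_KEY`, when available.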
#[async_trait::async_trait]
pub trait ShuffleReader: Send + Sync {
async fn read_partition(
&self,
partition_id: usize,
) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;
fn partition_size(&self, partition_id: usize) -> Result<usize>;
fn total_loss(&self) -> Option<f64>;
}
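/// Sorts a stream of batches tagged with `PART_ID_COLUMN` into IVF partitions,
/// persisting them so that each partition can later be read back independently.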
#[async_trait::async_trait]
pub trait Shuffler: Send + Sync {
async fn shuffle(
&self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<Box<dyn ShuffleReader>>;
}
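/// Legacy shuffler that writes one Lance file per partition
/// (`ivf_{partition_id}.lance` under `output_dir`).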
pub struct IvfShuffler {
object_store: Arc<ObjectStore>,
output_dir: Path,
num_partitions: usize,
format_version: LanceFileVersion,
progress: Arc<dyn crate::progress::IndexBuildProgress>,
}
impl IvfShuffler {
pub fn new(output_dir: Path, num_partitions: usize) -> Self {
Self {
object_store: Arc::new(ObjectStore::local()),
output_dir,
num_partitions,
format_version: LanceFileVersion::V2_0,
progress: crate::progress::noop_progress(),
}
}
pub fn with_format_version(mut self, format_version: LanceFileVersion) -> Self {
self.format_version = format_version;
self
}
pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
self.progress = progress;
self
}
}
#[async_trait::async_trait]
impl Shuffler for IvfShuffler {
async fn shuffle(
&self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<Box<dyn ShuffleReader>> {
let num_partitions = self.num_partitions;
let mut partition_sizes = vec![0; num_partitions];
let schema = data.schema().without_column(PART_ID_COLUMN);
let mut writers = stream::iter(0..num_partitions)
.map(|partition_id| {
let part_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));
let spill_path = self.output_dir.child(format!("ivf_{}.spill", partition_id));
let object_store = self.object_store.clone();
let schema = schema.clone();
let format_version = self.format_version;
async move {
let writer = object_store.create(&part_path).await?;
let file_writer = FileWriter::try_new(
writer,
lance_core::datatypes::Schema::try_from(&schema)?,
FileWriterOptions {
format_version: Some(format_version),
..Default::default()
},
)?
.with_page_metadata_spill(object_store.clone(), spill_path);
Result::Ok(file_writer)
}
})
.buffered(self.object_store.io_parallelism())
.try_collect::<Vec<_>>()
.await?;
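        // Sort each batch by partition id on the compute pool and slice it into
        // contiguous per-partition runs. The partition id column is dropped since
        // the target file already identifies the partition.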
let mut parallel_sort_stream = data
.map(|batch| {
spawn_cpu(move || {
let batch = batch?;
let loss = batch
.metadata()
.get(LOSS_METADATA_KEY)
                        .and_then(|s| s.parse::<f64>().ok())
                        .unwrap_or(0.0);
let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
                    let indices = sort_to_indices(part_ids, None, None)?;
let batch = batch.take(&indices)?;
let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
let batch = batch.drop_column(PART_ID_COLUMN)?;
let mut partition_buffers = vec![Vec::new(); num_partitions];
let mut start = 0;
while start < batch.num_rows() {
let part_id: u32 = part_ids.value(start);
let mut end = start + 1;
while end < batch.num_rows() && part_ids.value(end) == part_id {
end += 1;
}
let part_batches = &mut partition_buffers[part_id as usize];
part_batches.push(batch.slice(start, end - start));
start = end;
}
Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
})
})
.buffered(get_num_compute_intensive_cpus());
let mut total_loss = 0.0;
let mut num_rows = 0u64;
while let Some(shuffled) = parallel_sort_stream.next().await {
let (shuffled, loss) = shuffled?;
total_loss += loss;
let mut futs = Vec::new();
for (part_id, (writer, batches)) in writers.iter_mut().zip(shuffled.iter()).enumerate()
{
if !batches.is_empty() {
let rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
partition_sizes[part_id] += rows;
num_rows += rows as u64;
futs.push(writer.write_batches(batches.iter()));
}
}
try_join_all(futs).await?;
self.progress.stage_progress("shuffle", num_rows).await?;
}
for writer in writers.iter_mut() {
writer.finish().await?;
}
Ok(Box::new(IvfShufflerReader::new(
self.object_store.clone(),
self.output_dir.clone(),
partition_sizes,
total_loss,
)))
}
}
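/// Reader for [`IvfShuffler`] output: each partition is streamed back in full
/// from its own Lance file.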
pub struct IvfShufflerReader {
scheduler: Arc<ScanScheduler>,
output_dir: Path,
partition_sizes: Vec<usize>,
loss: f64,
}
impl IvfShufflerReader {
pub fn new(
object_store: Arc<ObjectStore>,
output_dir: Path,
partition_sizes: Vec<usize>,
loss: f64,
) -> Self {
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
let scheduler = ScanScheduler::new(object_store, scheduler_config);
Self {
scheduler,
output_dir,
partition_sizes,
loss,
}
}
}
#[async_trait::async_trait]
impl ShuffleReader for IvfShufflerReader {
async fn read_partition(
&self,
partition_id: usize,
) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
if partition_id >= self.partition_sizes.len() {
return Ok(None);
}
let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));
let reader = FileReader::try_open(
self.scheduler
.open_file(&partition_path, &CachedFileSize::unknown())
.await?,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
let schema: Schema = reader.schema().as_ref().into();
Ok(Some(Box::new(RecordBatchStreamAdapter::new(
Arc::new(schema),
reader.read_stream(
                ReadBatchParams::RangeFull,
u32::MAX,
16,
FilterExpression::no_filter(),
)?,
))))
}
fn partition_size(&self, partition_id: usize) -> Result<usize> {
Ok(self.partition_sizes.get(partition_id).copied().unwrap_or(0))
}
fn total_loss(&self) -> Option<f64> {
Some(self.loss)
}
}
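/// A [`ShuffleReader`] over no data, returned when the input stream was empty.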
pub struct EmptyReader;
#[async_trait::async_trait]
impl ShuffleReader for EmptyReader {
async fn read_partition(
&self,
_partition_id: usize,
) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
Ok(None)
}
fn partition_size(&self, _partition_id: usize) -> Result<usize> {
Ok(0)
}
fn total_loss(&self) -> Option<f64> {
None
}
}
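/// Creates the default shuffler: [`TwoFileShuffler`] unless the
/// `LANCE_LEGACY_SHUFFLER` environment variable is set to `1` or `true`, in
/// which case the legacy [`IvfShuffler`] is used. Note that `format_version`
/// only applies to the legacy shuffler.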
pub fn create_ivf_shuffler(
output_dir: Path,
num_partitions: usize,
format_version: LanceFileVersion,
progress: Option<Arc<dyn crate::progress::IndexBuildProgress>>,
) -> Box<dyn Shuffler> {
let use_legacy = std::env::var("LANCE_LEGACY_SHUFFLER")
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if use_legacy {
let mut shuffler =
IvfShuffler::new(output_dir, num_partitions).with_format_version(format_version);
if let Some(progress) = progress {
shuffler = shuffler.with_progress(progress);
}
Box::new(shuffler)
} else {
let mut shuffler = TwoFileShuffler::new(output_dir, num_partitions);
if let Some(progress) = progress {
shuffler = shuffler.with_progress(progress);
}
Box::new(shuffler)
}
}
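/// Default target size, in bytes, of the rechunked batches written by
/// [`TwoFileShuffler`] (128 MiB).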
const DEFAULT_SHUFFLE_BATCH_BYTES: usize = 128 * 1024 * 1024;
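/// Returns the shuffle batch size, honoring the `LANCE_SHUFFLE_BATCH_BYTES`
/// environment variable and falling back to the default when it is unset,
/// unparsable, or zero.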
fn shuffle_batch_bytes() -> usize {
let batch_size = std::env::var("LANCE_SHUFFLE_BATCH_BYTES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_SHUFFLE_BATCH_BYTES);
if batch_size == 0 {
log::warn!(
"LANCE_SHUFFLE_BATCH_BYTES is 0, using default of {}",
DEFAULT_SHUFFLE_BATCH_BYTES
);
DEFAULT_SHUFFLE_BATCH_BYTES
} else {
batch_size
}
}
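/// Shuffler that writes exactly two files regardless of the partition count:
///
/// * `shuffle_data.lance`: every rechunked batch, sorted by partition id, with
///   the partition id column removed.
/// * `shuffle_offsets.lance`: `num_partitions` rows per data batch, where each
///   row is the global end offset of that partition's rows within the batch.
///
/// Reading a partition back then reduces to one ranged read per batch, avoiding
/// the small-file overhead of one file per partition.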
pub struct TwoFileShuffler {
object_store: Arc<ObjectStore>,
output_dir: Path,
num_partitions: usize,
batch_size_bytes: usize,
progress: Arc<dyn crate::progress::IndexBuildProgress>,
}
impl TwoFileShuffler {
pub fn new(output_dir: Path, num_partitions: usize) -> Self {
Self {
object_store: Arc::new(ObjectStore::local()),
output_dir,
num_partitions,
batch_size_bytes: shuffle_batch_bytes(),
progress: crate::progress::noop_progress(),
}
}
pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
self.progress = progress;
self
}
#[cfg(test)]
fn with_batch_size_bytes(mut self, batch_size_bytes: usize) -> Self {
self.batch_size_bytes = batch_size_bytes;
self
}
}
#[async_trait::async_trait]
impl Shuffler for TwoFileShuffler {
async fn shuffle(
&self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<Box<dyn ShuffleReader>> {
let num_partitions = self.num_partitions;
let full_schema = Arc::new(data.schema().as_ref().clone());
let schema = data.schema().without_column(PART_ID_COLUMN);
let offsets_schema = Arc::new(Schema::new(vec![Field::new(
"offset",
DataType::UInt64,
false,
)]));
let batch_size_bytes = self.batch_size_bytes;
let total_loss = Arc::new(Mutex::new(0.0f64));
let loss_ref = total_loss.clone();
let loss_stream = data.map(move |result| {
result.inspect(|batch| {
let loss = batch
.metadata()
.get(LOSS_METADATA_KEY)
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(0.0);
*loss_ref.lock().unwrap() += loss;
})
});
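        // Rechunk the input into batches of roughly `batch_size_bytes` so that
        // the per-partition runs written below are large enough to read back
        // efficiently.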
let rechunked = rechunk_stream_by_size(
loss_stream,
full_schema,
batch_size_bytes,
batch_size_bytes * 2,
);
let data_path = self.output_dir.child("shuffle_data.lance");
let spill_path = self.output_dir.child("shuffle_data.spill");
let writer = self.object_store.create(&data_path).await?;
let mut file_writer = FileWriter::try_new(
writer,
lance_core::datatypes::Schema::try_from(&schema)?,
Default::default(),
)?
.with_page_metadata_spill(self.object_store.clone(), spill_path);
let offsets_path = self.output_dir.child("shuffle_offsets.lance");
let spill_path = self.output_dir.child("shuffle_offsets.spill");
let writer = self.object_store.create(&offsets_path).await?;
let mut offsets_writer = FileWriter::try_new(
writer,
lance_core::datatypes::Schema::try_from(offsets_schema.as_ref())?,
Default::default(),
)?
.with_page_metadata_spill(self.object_store.clone(), spill_path);
        let mut num_batches: u64 = 0;
let mut partition_counts: Vec<u64> = vec![0; num_partitions];
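        // `global_row_count` is the absolute row offset in the data file, used to
        // globalize the per-batch offsets; `rows_processed` drives progress only.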
let mut global_row_count: u64 = 0;
let mut rows_processed: u64 = 0;
let mut rechunked = std::pin::pin!(rechunked);
while let Some(batch) = rechunked.next().await {
            num_batches += 1;
let batch = batch?;
let np = num_partitions;
let num_rows = batch.num_rows() as u64;
let (sorted_batch, batch_offsets) = spawn_cpu(move || {
let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
let indices = sort_to_indices(part_ids, None, None)?;
let batch = batch.take(&indices)?;
let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
let batch = batch.drop_column(PART_ID_COLUMN)?;
let mut partition_counts = vec![0u64; np];
for i in 0..part_ids.len() {
let pid = part_ids.value(i) as usize;
if pid < np {
partition_counts[pid] += 1;
} else {
log::warn!("Partition ID {} is out of range [0, {})", pid, np);
}
}
let mut batch_offsets = Vec::with_capacity(np);
let mut running = 0u64;
for count in &partition_counts {
running += count;
batch_offsets.push(running);
}
Ok::<(RecordBatch, Vec<u64>), Error>((batch, batch_offsets))
})
.await?;
file_writer.write_batch(&sorted_batch).await?;
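            // Globalize the batch-local cumulative end offsets and fold the
            // per-batch counts into the per-partition totals.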
let mut adjusted_offsets = Vec::with_capacity(batch_offsets.len());
let mut last_offset = 0;
for (idx, offset) in batch_offsets.iter().enumerate() {
adjusted_offsets.push(global_row_count + offset);
partition_counts[idx] += offset - last_offset;
last_offset = *offset;
}
global_row_count += sorted_batch.num_rows() as u64;
let offsets_batch = RecordBatch::try_new(
offsets_schema.clone(),
vec![Arc::new(UInt64Array::from(adjusted_offsets))],
)?;
offsets_writer.write_batch(&offsets_batch).await?;
rows_processed += num_rows;
self.progress
.stage_progress("shuffle", rows_processed)
.await?;
}
file_writer.finish().await?;
offsets_writer.finish().await?;
let total_loss_val = *total_loss.lock().unwrap();
TwoFileShuffleReader::try_new(
self.object_store.clone(),
self.output_dir.clone(),
num_partitions,
num_batches,
partition_counts,
total_loss_val,
)
.await
}
}
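/// Reader for [`TwoFileShuffler`] output. A partition's rows are scattered
/// across the sorted batches of `shuffle_data.lance`; the offsets file is used
/// to compute one row range per batch, which are then read back as a single
/// multi-range stream.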
pub struct TwoFileShuffleReader {
_scheduler: Arc<ScanScheduler>,
file_reader: FileReader,
offsets_reader: FileReader,
num_partitions: usize,
num_batches: u64,
partition_counts: Vec<u64>,
total_loss: f64,
}
impl TwoFileShuffleReader {
async fn try_new(
object_store: Arc<ObjectStore>,
output_dir: Path,
num_partitions: usize,
num_batches: u64,
partition_counts: Vec<u64>,
total_loss: f64,
) -> Result<Box<dyn ShuffleReader>> {
if num_batches == 0 {
return Ok(Box::new(EmptyReader));
}
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
let scheduler = ScanScheduler::new(object_store, scheduler_config);
let data_path = output_dir.child("shuffle_data.lance");
let file_reader = FileReader::try_open(
scheduler
.open_file(&data_path, &CachedFileSize::unknown())
.await?,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
let offsets_path = output_dir.child("shuffle_offsets.lance");
let offsets_reader = FileReader::try_open(
scheduler
.open_file(&offsets_path, &CachedFileSize::unknown())
.await?,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
Ok(Box::new(Self {
_scheduler: scheduler,
file_reader,
offsets_reader,
num_partitions,
num_batches,
partition_counts,
total_loss,
}))
}
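    /// Computes the row range occupied by `partition_id` within each written batch.
    ///
    /// The offsets file holds `num_partitions` entries per batch; the entry at
    /// `batch_idx * num_partitions + partition_id` is the global end offset of the
    /// partition's rows in that batch, so each range starts at the preceding entry
    /// (or at 0 for the very first entry of the file). With a single partition the
    /// position list contains repeats; the take returns repeated values, which the
    /// pairing loop below relies on.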
async fn partition_ranges(&self, partition_id: usize) -> Result<Vec<Range<u64>>> {
let mut positions = Vec::with_capacity(self.num_batches as usize * 2);
for batch_idx in 0..self.num_batches {
            let end_pos = u32::try_from(batch_idx as usize * self.num_partitions + partition_id)
                .map_err(|_| Error::invalid_input("There are more than 2^32 partition offsets in the offsets file. Need to support 64-bit take"))?;
if end_pos != 0 {
positions.push(end_pos - 1);
}
positions.push(end_pos);
}
let positions = UInt32Array::from(positions);
let num_positions = positions.len() as u32;
let offsets_stream = self.offsets_reader.read_stream(
ReadBatchParams::Indices(positions),
num_positions,
1,
FilterExpression::no_filter(),
)?;
let schema = offsets_stream.schema().clone();
let offsets = offsets_stream.try_collect::<Vec<_>>().await?;
let offsets = if offsets.is_empty() {
            unreachable!("num_batches > 0 guarantees at least one offsets batch")
} else if offsets.len() == 1 {
offsets.into_iter().next().unwrap()
} else {
concat_batches(&schema, &offsets)?
};
let offsets = offsets.column(0).as_primitive::<UInt64Type>();
let mut offsets_iter = offsets.values().iter().copied();
let mut ranges = Vec::with_capacity(self.num_batches as usize);
for batch_idx in 0..self.num_batches {
if batch_idx == 0 && partition_id == 0 {
ranges.push(0..offsets_iter.next().unwrap());
} else {
ranges.push(offsets_iter.next().unwrap()..offsets_iter.next().unwrap());
}
}
Ok(ranges)
}
}
#[async_trait::async_trait]
impl ShuffleReader for TwoFileShuffleReader {
async fn read_partition(
&self,
partition_id: usize,
) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
if partition_id >= self.num_partitions {
return Ok(None);
}
if self.partition_counts[partition_id] == 0 {
return Ok(None);
}
let ranges = self.partition_ranges(partition_id).await?;
if ranges.is_empty() {
return Ok(None);
}
let schema: Schema = self.file_reader.schema().as_ref().into();
Ok(Some(Box::new(RecordBatchStreamAdapter::new(
Arc::new(schema),
self.file_reader.read_stream(
ReadBatchParams::Ranges(ranges.into()),
u32::MAX,
16,
FilterExpression::no_filter(),
)?,
))))
}
fn partition_size(&self, partition_id: usize) -> Result<usize> {
Ok(self
.partition_counts
.get(partition_id)
.copied()
.unwrap_or(0) as usize)
}
fn total_loss(&self) -> Option<f64> {
Some(self.total_loss)
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, RecordBatch, UInt32Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use futures::stream;
use lance_arrow::RecordBatchExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
fn make_batch(part_ids: &[u32], values: &[i32], loss: Option<f64>) -> RecordBatch {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(PART_ID_COLUMN, DataType::UInt32, false),
Field::new("val", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt32Array::from(part_ids.to_vec())),
Arc::new(Int32Array::from(values.to_vec())),
],
)
.unwrap();
if let Some(loss_val) = loss {
batch
.add_metadata(LOSS_METADATA_KEY.to_owned(), loss_val.to_string())
.unwrap()
} else {
batch
}
}
fn batches_to_stream(
batches: Vec<RecordBatch>,
) -> Box<dyn RecordBatchStream + Unpin + 'static> {
let schema = batches[0].schema();
let stream = stream::iter(batches.into_iter().map(Ok));
Box::new(RecordBatchStreamAdapter::new(schema, stream))
}
async fn collect_partition(
reader: &dyn ShuffleReader,
partition_id: usize,
) -> Option<RecordBatch> {
let stream = reader.read_partition(partition_id).await.unwrap()?;
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
if batches.is_empty() {
return None;
}
Some(arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap())
}
#[tokio::test]
async fn test_two_file_shuffler_round_trip() {
let dir = TempStrDir::default();
let output_dir = Path::from(dir.as_ref());
let num_partitions = 3;
let batch = make_batch(&[0, 1, 2, 0, 1], &[10, 20, 30, 40, 50], None);
let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
let stream = batches_to_stream(vec![batch]);
let reader = shuffler.shuffle(stream).await.unwrap();
assert_eq!(reader.partition_size(0).unwrap(), 2);
assert_eq!(reader.partition_size(1).unwrap(), 2);
assert_eq!(reader.partition_size(2).unwrap(), 1);
let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
v.sort();
assert_eq!(v, vec![10, 40]);
let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
v.sort();
assert_eq!(v, vec![20, 50]);
let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
let vals: &Int32Array = p2.column_by_name("val").unwrap().as_primitive();
let v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
assert_eq!(v, vec![30]);
assert!(reader.read_partition(3).await.unwrap().is_none());
}
#[tokio::test]
async fn test_two_file_shuffler_empty_partitions() {
let dir = TempStrDir::default();
let output_dir = Path::from(dir.as_ref());
let num_partitions = 5;
let batch = make_batch(&[0, 3, 0, 3], &[10, 20, 30, 40], None);
let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
let stream = batches_to_stream(vec![batch]);
let reader = shuffler.shuffle(stream).await.unwrap();
assert_eq!(reader.partition_size(0).unwrap(), 2);
assert_eq!(reader.partition_size(1).unwrap(), 0);
assert_eq!(reader.partition_size(2).unwrap(), 0);
assert_eq!(reader.partition_size(3).unwrap(), 2);
assert_eq!(reader.partition_size(4).unwrap(), 0);
assert!(reader.read_partition(1).await.unwrap().is_none());
assert!(reader.read_partition(2).await.unwrap().is_none());
assert!(reader.read_partition(4).await.unwrap().is_none());
let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
assert_eq!(p0.num_rows(), 2);
let p3 = collect_partition(reader.as_ref(), 3).await.unwrap();
assert_eq!(p3.num_rows(), 2);
}
#[tokio::test]
async fn test_two_file_shuffler_loss_tracking() {
let dir = TempStrDir::default();
let output_dir = Path::from(dir.as_ref());
let num_partitions = 2;
let batch1 = make_batch(&[0, 1], &[10, 20], Some(1.5));
let batch2 = make_batch(&[0, 1], &[30, 40], Some(2.5));
let batch3 = make_batch(&[0], &[50], Some(0.25));
let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
let stream = batches_to_stream(vec![batch1, batch2, batch3]);
let reader = shuffler.shuffle(stream).await.unwrap();
let loss = reader.total_loss().unwrap();
assert!((loss - 4.25).abs() < 1e-10, "expected 4.25, got {}", loss);
}
#[tokio::test]
async fn test_two_file_shuffler_single_batch() {
let dir = TempStrDir::default();
let output_dir = Path::from(dir.as_ref());
let num_partitions = 2;
let batch = make_batch(&[1, 0], &[100, 200], Some(3.0));
let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
let stream = batches_to_stream(vec![batch]);
let reader = shuffler.shuffle(stream).await.unwrap();
assert_eq!(reader.partition_size(0).unwrap(), 1);
assert_eq!(reader.partition_size(1).unwrap(), 1);
let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
assert_eq!(vals.value(0), 200);
let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
assert_eq!(vals.value(0), 100);
assert!((reader.total_loss().unwrap() - 3.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_two_file_shuffler_multiple_batches() {
let dir = TempStrDir::default();
let output_dir = Path::from(dir.as_ref());
let num_partitions = 3;
let batch1 = make_batch(&[0, 1, 2], &[10, 20, 30], Some(1.0));
let batch2 = make_batch(&[2, 0, 1], &[40, 50, 60], Some(2.0));
let batch3 = make_batch(&[1, 2, 0], &[70, 80, 90], Some(3.0));
let shuffler = TwoFileShuffler::new(output_dir, num_partitions)
.with_batch_size_bytes(16);
let stream = batches_to_stream(vec![batch1, batch2, batch3]);
let reader = shuffler.shuffle(stream).await.unwrap();
assert_eq!(reader.partition_size(0).unwrap(), 3);
let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
v.sort();
assert_eq!(v, vec![10, 50, 90]);
assert_eq!(reader.partition_size(1).unwrap(), 3);
let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
v.sort();
assert_eq!(v, vec![20, 60, 70]);
assert_eq!(reader.partition_size(2).unwrap(), 3);
let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
let vals: &Int32Array = p2.column_by_name("val").unwrap().as_primitive();
let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
v.sort();
assert_eq!(v, vec![30, 40, 80]);
assert!((reader.total_loss().unwrap() - 6.0).abs() < 1e-10);
}
}