1use std::ops::Range;
8use std::sync::atomic::AtomicU64;
9use std::sync::{Arc, Mutex};
10
11use arrow::compute::concat_batches;
12use arrow::datatypes::UInt64Type;
13use arrow::{array::AsArray, compute::sort_to_indices};
14use arrow_array::{RecordBatch, UInt32Array, UInt64Array};
15use arrow_schema::{DataType, Field, Schema};
16use futures::{future::try_join_all, prelude::*};
17use lance_arrow::stream::rechunk_stream_by_size;
18use lance_arrow::{RecordBatchExt, SchemaExt};
19use lance_core::{
20 Error, Result,
21 cache::LanceCache,
22 utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
23};
24use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
25use lance_encoding::version::LanceFileVersion;
26use lance_file::reader::{FileReader, FileReaderOptions};
27use lance_file::writer::{FileWriter, FileWriterOptions};
28use lance_io::{
29 ReadBatchParams,
30 object_store::ObjectStore,
31 scheduler::{ScanScheduler, SchedulerConfig},
32 stream::{RecordBatchStream, RecordBatchStreamAdapter},
33 utils::CachedFileSize,
34};
35use object_store::path::Path;
36
37use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
38
39#[async_trait::async_trait]
40pub trait ShuffleReader: Send + Sync {
42 async fn read_partition(
46 &self,
47 partition_id: usize,
48 ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;
49
50 fn partition_size(&self, partition_id: usize) -> Result<usize>;
52
53 fn total_loss(&self) -> Option<f64>;
58}
59
60#[async_trait::async_trait]
61pub trait Shuffler: Send + Sync {
64 async fn shuffle(
67 &self,
68 data: Box<dyn RecordBatchStream + Unpin + 'static>,
69 ) -> Result<Box<dyn ShuffleReader>>;
70}
71
72pub struct IvfShuffler {
73 object_store: Arc<ObjectStore>,
74 output_dir: Path,
75 num_partitions: usize,
76 format_version: LanceFileVersion,
77
78 progress: Arc<dyn crate::progress::IndexBuildProgress>,
79}
80
81impl IvfShuffler {
82 pub fn new(output_dir: Path, num_partitions: usize) -> Self {
83 Self {
84 object_store: Arc::new(ObjectStore::local()),
85 output_dir,
86 num_partitions,
87 format_version: LanceFileVersion::V2_0,
88 progress: crate::progress::noop_progress(),
89 }
90 }
91
92 pub fn with_format_version(mut self, format_version: LanceFileVersion) -> Self {
93 self.format_version = format_version;
94 self
95 }
96
97 pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
98 self.progress = progress;
99 self
100 }
101}
102
103#[async_trait::async_trait]
104impl Shuffler for IvfShuffler {
105 async fn shuffle(
106 &self,
107 data: Box<dyn RecordBatchStream + Unpin + 'static>,
108 ) -> Result<Box<dyn ShuffleReader>> {
109 let num_partitions = self.num_partitions;
110 let mut partition_sizes = vec![0; num_partitions];
111 let schema = data.schema().without_column(PART_ID_COLUMN);
112 let mut writers = stream::iter(0..num_partitions)
113 .map(|partition_id| {
114 let part_path = self
115 .output_dir
116 .clone()
117 .join(format!("ivf_{}.lance", partition_id));
118 let spill_path = self
119 .output_dir
120 .clone()
121 .join(format!("ivf_{}.spill", partition_id));
122 let object_store = self.object_store.clone();
123 let schema = schema.clone();
124 let format_version = self.format_version;
125 async move {
126 let writer = object_store.create(&part_path).await?;
127 let file_writer = FileWriter::try_new(
128 writer,
129 lance_core::datatypes::Schema::try_from(&schema)?,
130 FileWriterOptions {
131 format_version: Some(format_version),
132 ..Default::default()
133 },
134 )?
135 .with_page_metadata_spill(object_store.clone(), spill_path);
136 Result::Ok(file_writer)
137 }
138 })
139 .buffered(self.object_store.io_parallelism())
140 .try_collect::<Vec<_>>()
141 .await?;
142 let mut parallel_sort_stream = data
143 .map(|batch| {
144 spawn_cpu(move || {
145 let batch = batch?;
146
147 let loss = batch
148 .metadata()
149 .get(LOSS_METADATA_KEY)
150 .map(|s| s.parse::<f64>().unwrap_or_default())
151 .unwrap_or_default();
152
153 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
154
155 let indices = sort_to_indices(&part_ids, None, None)?;
156 let batch = batch.take(&indices)?;
157
158 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
159 let batch = batch.drop_column(PART_ID_COLUMN)?;
160
161 let mut partition_buffers = vec![Vec::new(); num_partitions];
162
163 let mut start = 0;
164 while start < batch.num_rows() {
165 let part_id: u32 = part_ids.value(start);
166 let mut end = start + 1;
167 while end < batch.num_rows() && part_ids.value(end) == part_id {
168 end += 1;
169 }
170
171 let part_batches = &mut partition_buffers[part_id as usize];
172 part_batches.push(batch.slice(start, end - start));
173 start = end;
174 }
175
176 Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
177 })
178 })
179 .buffered(get_num_compute_intensive_cpus());
180
181 let mut total_loss = 0.0;
182 let mut num_rows = 0u64;
183 while let Some(shuffled) = parallel_sort_stream.next().await {
184 let (shuffled, loss) = shuffled?;
185 total_loss += loss;
186
187 let mut futs = Vec::new();
188 for (part_id, (writer, batches)) in writers.iter_mut().zip(shuffled.iter()).enumerate()
189 {
190 if !batches.is_empty() {
191 let rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
192 partition_sizes[part_id] += rows;
193 num_rows += rows as u64;
194 futs.push(writer.write_batches(batches.iter()));
195 }
196 }
197 try_join_all(futs).await?;
198
199 self.progress.stage_progress("shuffle", num_rows).await?;
200 }
201
202 for writer in writers.iter_mut() {
204 writer.finish().await?;
205 }
206
207 Ok(Box::new(IvfShufflerReader::new(
208 self.object_store.clone(),
209 self.output_dir.clone(),
210 partition_sizes,
211 total_loss,
212 )))
213 }
214}
215
216pub struct IvfShufflerReader {
217 scheduler: Arc<ScanScheduler>,
218 output_dir: Path,
219 partition_sizes: Vec<usize>,
220 loss: f64,
221}
222
223impl IvfShufflerReader {
224 pub fn new(
225 object_store: Arc<ObjectStore>,
226 output_dir: Path,
227 partition_sizes: Vec<usize>,
228 loss: f64,
229 ) -> Self {
230 let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
231 let scheduler = ScanScheduler::new(object_store, scheduler_config);
232 Self {
233 scheduler,
234 output_dir,
235 partition_sizes,
236 loss,
237 }
238 }
239}
240
241#[async_trait::async_trait]
242impl ShuffleReader for IvfShufflerReader {
243 async fn read_partition(
244 &self,
245 partition_id: usize,
246 ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
247 if partition_id >= self.partition_sizes.len() {
248 return Ok(None);
249 }
250
251 let partition_path = self
252 .output_dir
253 .clone()
254 .join(format!("ivf_{}.lance", partition_id));
255
256 let reader = FileReader::try_open(
257 self.scheduler
258 .open_file(&partition_path, &CachedFileSize::unknown())
259 .await?,
260 None,
261 Arc::<DecoderPlugins>::default(),
262 &LanceCache::no_cache(),
263 FileReaderOptions::default(),
264 )
265 .await?;
266 let schema: Schema = reader.schema().as_ref().into();
267 let stream = reader
268 .read_stream(
269 lance_io::ReadBatchParams::RangeFull,
270 u32::MAX,
271 16,
272 FilterExpression::no_filter(),
273 )
274 .await?;
275 Ok(Some(Box::new(RecordBatchStreamAdapter::new(
276 Arc::new(schema),
277 stream,
278 ))))
279 }
280
281 fn partition_size(&self, partition_id: usize) -> Result<usize> {
282 Ok(self.partition_sizes.get(partition_id).copied().unwrap_or(0))
283 }
284
285 fn total_loss(&self) -> Option<f64> {
286 Some(self.loss)
287 }
288}
289
/// A [`ShuffleReader`] with no data: every partition is absent and has size 0.
pub struct EmptyReader;
291
292#[async_trait::async_trait]
293impl ShuffleReader for EmptyReader {
294 async fn read_partition(
295 &self,
296 _partition_id: usize,
297 ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
298 Ok(None)
299 }
300
301 fn partition_size(&self, _partition_id: usize) -> Result<usize> {
302 Ok(0)
303 }
304
305 fn total_loss(&self) -> Option<f64> {
306 None
307 }
308}
309
310pub fn create_ivf_shuffler(
318 output_dir: Path,
319 num_partitions: usize,
320 format_version: LanceFileVersion,
321 progress: Option<Arc<dyn crate::progress::IndexBuildProgress>>,
322) -> Box<dyn Shuffler> {
323 let use_legacy = std::env::var("LANCE_LEGACY_SHUFFLER")
324 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
325 .unwrap_or(false);
326 if use_legacy {
327 let mut shuffler =
328 IvfShuffler::new(output_dir, num_partitions).with_format_version(format_version);
329 if let Some(progress) = progress {
330 shuffler = shuffler.with_progress(progress);
331 }
332 Box::new(shuffler)
333 } else {
334 let mut shuffler = TwoFileShuffler::new(output_dir, num_partitions);
335 if let Some(progress) = progress {
336 shuffler = shuffler.with_progress(progress);
337 }
338 Box::new(shuffler)
339 }
340}
341
/// Default batch size in bytes (128 MiB), used when
/// `LANCE_SHUFFLE_BATCH_BYTES` is unset, unparsable, or zero.
const DEFAULT_SHUFFLE_BATCH_BYTES: usize = 128 << 20;
343
344fn shuffle_batch_bytes() -> usize {
351 let batch_size = std::env::var("LANCE_SHUFFLE_BATCH_BYTES")
352 .ok()
353 .and_then(|s| s.parse().ok())
354 .unwrap_or(DEFAULT_SHUFFLE_BATCH_BYTES);
355 if batch_size == 0 {
356 log::warn!(
357 "LANCE_SHUFFLE_BATCH_BYTES is 0, using default of {}",
358 DEFAULT_SHUFFLE_BATCH_BYTES
359 );
360 DEFAULT_SHUFFLE_BATCH_BYTES
361 } else {
362 batch_size
363 }
364}
365
366pub struct TwoFileShuffler {
379 object_store: Arc<ObjectStore>,
380 output_dir: Path,
381 num_partitions: usize,
382 batch_size_bytes: usize,
383
384 progress: Arc<dyn crate::progress::IndexBuildProgress>,
385}
386
387impl TwoFileShuffler {
388 pub fn new(output_dir: Path, num_partitions: usize) -> Self {
389 Self {
390 object_store: Arc::new(ObjectStore::local()),
391 output_dir,
392 num_partitions,
393 batch_size_bytes: shuffle_batch_bytes(),
394 progress: crate::progress::noop_progress(),
395 }
396 }
397
398 pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
399 self.progress = progress;
400 self
401 }
402
403 #[cfg(test)]
404 fn with_batch_size_bytes(mut self, batch_size_bytes: usize) -> Self {
405 self.batch_size_bytes = batch_size_bytes;
406 self
407 }
408}
409
410#[async_trait::async_trait]
411impl Shuffler for TwoFileShuffler {
412 async fn shuffle(
413 &self,
414 data: Box<dyn RecordBatchStream + Unpin + 'static>,
415 ) -> Result<Box<dyn ShuffleReader>> {
416 let num_partitions = self.num_partitions;
417 let full_schema = Arc::new(data.schema().as_ref().clone());
418 let schema = data.schema().without_column(PART_ID_COLUMN);
420 let offsets_schema = Arc::new(Schema::new(vec![Field::new(
421 "offset",
422 DataType::UInt64,
423 false,
424 )]));
425 let batch_size_bytes = self.batch_size_bytes;
426
427 let total_loss = Arc::new(Mutex::new(0.0f64));
429 let loss_ref = total_loss.clone();
430 let loss_stream = data.map(move |result| {
431 result.inspect(|batch| {
432 let loss = batch
433 .metadata()
434 .get(LOSS_METADATA_KEY)
435 .and_then(|s| s.parse::<f64>().ok())
436 .unwrap_or(0.0);
437 *loss_ref.lock().unwrap() += loss;
438 })
439 });
440
441 let rechunked = rechunk_stream_by_size(
443 loss_stream,
444 full_schema,
445 batch_size_bytes,
446 batch_size_bytes * 2,
447 );
448
449 let data_path = self.output_dir.clone().join("shuffle_data.lance");
451 let spill_path = self.output_dir.clone().join("shuffle_data.spill");
452 let writer = self.object_store.create(&data_path).await?;
453 let mut file_writer = FileWriter::try_new(
454 writer,
455 lance_core::datatypes::Schema::try_from(&schema)?,
456 Default::default(),
457 )?
458 .with_page_metadata_spill(self.object_store.clone(), spill_path);
459
460 let offsets_path = self.output_dir.clone().join("shuffle_offsets.lance");
462 let spill_path = self.output_dir.clone().join("shuffle_offsets.spill");
463 let writer = self.object_store.create(&offsets_path).await?;
464 let mut offsets_writer = FileWriter::try_new(
465 writer,
466 lance_core::datatypes::Schema::try_from(offsets_schema.as_ref())?,
467 Default::default(),
468 )?
469 .with_page_metadata_spill(self.object_store.clone(), spill_path);
470
471 let num_batches = Arc::new(AtomicU64::new(0));
472 let num_batches_ref = num_batches.clone();
473 let mut partition_counts: Vec<u64> = vec![0; num_partitions];
474 let mut global_row_count: u64 = 0;
475 let mut rows_processed: u64 = 0;
476
477 let mut rechunked = std::pin::pin!(rechunked);
478 while let Some(batch) = rechunked.next().await {
479 num_batches_ref.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
480 let batch = batch?;
481 let np = num_partitions;
482 let num_rows = batch.num_rows() as u64;
483
484 let (sorted_batch, batch_offsets) = spawn_cpu(move || {
486 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
487 let indices = sort_to_indices(part_ids, None, None)?;
488 let batch = batch.take(&indices)?;
489
490 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
491 let batch = batch.drop_column(PART_ID_COLUMN)?;
492
493 let mut partition_counts = vec![0u64; np];
495 for i in 0..part_ids.len() {
496 let pid = part_ids.value(i) as usize;
497 if pid < np {
498 partition_counts[pid] += 1;
499 } else {
500 log::warn!("Partition ID {} is out of range [0, {})", pid, np);
501 }
502 }
503
504 let mut batch_offsets = Vec::with_capacity(np);
506 let mut running = 0u64;
507 for count in &partition_counts {
508 running += count;
509 batch_offsets.push(running);
510 }
511
512 Ok::<(RecordBatch, Vec<u64>), Error>((batch, batch_offsets))
513 })
514 .await?;
515
516 file_writer.write_batch(&sorted_batch).await?;
518
519 let mut adjusted_offsets = Vec::with_capacity(batch_offsets.len());
521 let mut last_offset = 0;
522 for (idx, offset) in batch_offsets.iter().enumerate() {
523 adjusted_offsets.push(global_row_count + offset);
524 partition_counts[idx] += offset - last_offset;
525 last_offset = *offset;
526 }
527 global_row_count += sorted_batch.num_rows() as u64;
528
529 let offsets_batch = RecordBatch::try_new(
531 offsets_schema.clone(),
532 vec![Arc::new(UInt64Array::from(adjusted_offsets))],
533 )?;
534 offsets_writer.write_batch(&offsets_batch).await?;
535
536 rows_processed += num_rows;
537 self.progress
538 .stage_progress("shuffle", rows_processed)
539 .await?;
540 }
541
542 file_writer.finish().await?;
544 offsets_writer.finish().await?;
545
546 let num_batches = num_batches.load(std::sync::atomic::Ordering::Relaxed);
547
548 let total_loss_val = *total_loss.lock().unwrap();
549
550 TwoFileShuffleReader::try_new(
551 self.object_store.clone(),
552 self.output_dir.clone(),
553 num_partitions,
554 num_batches,
555 partition_counts,
556 total_loss_val,
557 )
558 .await
559 }
560}
561
562pub struct TwoFileShuffleReader {
563 _scheduler: Arc<ScanScheduler>,
564 file_reader: FileReader,
565 offsets_reader: FileReader,
566 num_partitions: usize,
567 num_batches: u64,
568 partition_counts: Vec<u64>,
569 total_loss: f64,
570}
571
572impl TwoFileShuffleReader {
573 async fn try_new(
574 object_store: Arc<ObjectStore>,
575 output_dir: Path,
576 num_partitions: usize,
577 num_batches: u64,
578 partition_counts: Vec<u64>,
579 total_loss: f64,
580 ) -> Result<Box<dyn ShuffleReader>> {
581 if num_batches == 0 {
582 return Ok(Box::new(EmptyReader));
583 }
584
585 let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
586 let scheduler = ScanScheduler::new(object_store, scheduler_config);
587
588 let data_path = output_dir.clone().join("shuffle_data.lance");
589 let file_reader = FileReader::try_open(
590 scheduler
591 .open_file(&data_path, &CachedFileSize::unknown())
592 .await?,
593 None,
594 Arc::<DecoderPlugins>::default(),
595 &LanceCache::no_cache(),
596 FileReaderOptions::default(),
597 )
598 .await?;
599
600 let offsets_path = output_dir.clone().join("shuffle_offsets.lance");
601 let offsets_reader = FileReader::try_open(
602 scheduler
603 .open_file(&offsets_path, &CachedFileSize::unknown())
604 .await?,
605 None,
606 Arc::<DecoderPlugins>::default(),
607 &LanceCache::no_cache(),
608 FileReaderOptions::default(),
609 )
610 .await?;
611
612 Ok(Box::new(Self {
613 _scheduler: scheduler,
614 file_reader,
615 offsets_reader,
616 num_partitions,
617 num_batches,
618 partition_counts,
619 total_loss,
620 }))
621 }
622
623 async fn partition_ranges(&self, partition_id: usize) -> Result<Vec<Range<u64>>> {
624 let mut positions = Vec::with_capacity(self.num_batches as usize * 2);
625 for batch_idx in 0..self.num_batches {
626 let end_pos = u32::try_from(batch_idx as usize * self.num_partitions + partition_id)
627 .map_err(|_| Error::invalid_input("There are more than 2^32 partition offsets in the spill file. Need to support 64-bit take"))?;
628 if end_pos != 0 {
629 positions.push(end_pos - 1);
630 }
631 positions.push(end_pos);
632 }
633 let positions = UInt32Array::from(positions);
634 let num_positions = positions.len() as u32;
635 let offsets_stream = self
636 .offsets_reader
637 .read_stream(
638 ReadBatchParams::Indices(positions),
639 num_positions,
640 1,
641 FilterExpression::no_filter(),
642 )
643 .await?;
644 let schema = offsets_stream.schema().clone();
645 let offsets = offsets_stream.try_collect::<Vec<_>>().await?;
646 let offsets = if offsets.is_empty() {
647 unreachable!()
649 } else if offsets.len() == 1 {
650 offsets.into_iter().next().unwrap()
651 } else {
652 concat_batches(&schema, &offsets)?
653 };
654
655 let offsets = offsets.column(0).as_primitive::<UInt64Type>();
656 let mut offsets_iter = offsets.values().iter().copied();
657
658 let mut ranges = Vec::with_capacity(self.num_batches as usize);
659 for batch_idx in 0..self.num_batches {
660 if batch_idx == 0 && partition_id == 0 {
661 ranges.push(0..offsets_iter.next().unwrap());
663 } else {
664 ranges.push(offsets_iter.next().unwrap()..offsets_iter.next().unwrap());
665 }
666 }
667 Ok(ranges)
668 }
669}
670
671#[async_trait::async_trait]
672impl ShuffleReader for TwoFileShuffleReader {
673 async fn read_partition(
674 &self,
675 partition_id: usize,
676 ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
677 if partition_id >= self.num_partitions {
678 return Ok(None);
679 }
680 if self.partition_counts[partition_id] == 0 {
681 return Ok(None);
682 }
683
684 let ranges = self.partition_ranges(partition_id).await?;
685 if ranges.is_empty() {
686 return Ok(None);
687 }
688
689 let schema: Schema = self.file_reader.schema().as_ref().into();
690 let stream = self
691 .file_reader
692 .read_stream(
693 ReadBatchParams::Ranges(ranges.into()),
694 u32::MAX,
695 16,
696 FilterExpression::no_filter(),
697 )
698 .await?;
699 Ok(Some(Box::new(RecordBatchStreamAdapter::new(
700 Arc::new(schema),
701 stream,
702 ))))
703 }
704
705 fn partition_size(&self, partition_id: usize) -> Result<usize> {
706 Ok(self
707 .partition_counts
708 .get(partition_id)
709 .copied()
710 .unwrap_or(0) as usize)
711 }
712
713 fn total_loss(&self) -> Option<f64> {
714 Some(self.total_loss)
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{Int32Array, RecordBatch, UInt32Array};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use futures::stream;
    use lance_arrow::RecordBatchExt;
    use lance_core::utils::tempfile::TempStrDir;
    use lance_io::stream::RecordBatchStreamAdapter;

    use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};

    /// Builds a two-column batch (partition ID, `val`), optionally tagged
    /// with a loss value in its metadata.
    fn make_batch(part_ids: &[u32], values: &[i32], loss: Option<f64>) -> RecordBatch {
        let fields = vec![
            Field::new(PART_ID_COLUMN, DataType::UInt32, false),
            Field::new("val", DataType::Int32, false),
        ];
        let batch = RecordBatch::try_new(
            Arc::new(ArrowSchema::new(fields)),
            vec![
                Arc::new(UInt32Array::from(part_ids.to_vec())),
                Arc::new(Int32Array::from(values.to_vec())),
            ],
        )
        .unwrap();
        match loss {
            Some(loss_val) => batch
                .add_metadata(LOSS_METADATA_KEY.to_owned(), loss_val.to_string())
                .unwrap(),
            None => batch,
        }
    }

    /// Wraps the batches in a stream that uses the first batch's schema.
    fn batches_to_stream(
        batches: Vec<RecordBatch>,
    ) -> Box<dyn RecordBatchStream + Unpin + 'static> {
        let schema = batches[0].schema();
        let results = batches.into_iter().map(Ok);
        Box::new(RecordBatchStreamAdapter::new(schema, stream::iter(results)))
    }

    /// Reads a whole partition into one batch; `None` if there is no data.
    async fn collect_partition(
        reader: &dyn ShuffleReader,
        partition_id: usize,
    ) -> Option<RecordBatch> {
        let stream = reader.read_partition(partition_id).await.unwrap()?;
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
        let first = batches.first()?;
        Some(arrow::compute::concat_batches(&first.schema(), &batches).unwrap())
    }

    /// Runs `shuffler` over `batches` and returns the resulting reader.
    async fn run_shuffle(
        shuffler: TwoFileShuffler,
        batches: Vec<RecordBatch>,
    ) -> Box<dyn ShuffleReader> {
        shuffler.shuffle(batches_to_stream(batches)).await.unwrap()
    }

    /// Values of the `val` column, in stored order.
    fn vals(batch: &RecordBatch) -> Vec<i32> {
        let column: &Int32Array = batch.column_by_name("val").unwrap().as_primitive();
        column.iter().map(|x| x.unwrap()).collect()
    }

    /// Values of the `val` column, sorted ascending.
    fn sorted_vals(batch: &RecordBatch) -> Vec<i32> {
        let mut values = vals(batch);
        values.sort();
        values
    }

    #[tokio::test]
    async fn test_two_file_shuffler_round_trip() {
        let dir = TempStrDir::default();
        let output_dir = Path::from(dir.as_ref());
        let num_partitions = 3;

        let batch = make_batch(&[0, 1, 2, 0, 1], &[10, 20, 30, 40, 50], None);

        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
        let reader = run_shuffle(shuffler, vec![batch]).await;

        assert_eq!(reader.partition_size(0).unwrap(), 2);
        assert_eq!(reader.partition_size(1).unwrap(), 2);
        assert_eq!(reader.partition_size(2).unwrap(), 1);

        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
        assert_eq!(sorted_vals(&p0), vec![10, 40]);

        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
        assert_eq!(sorted_vals(&p1), vec![20, 50]);

        let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
        assert_eq!(vals(&p2), vec![30]);

        // One past the last partition yields nothing.
        assert!(reader.read_partition(3).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_two_file_shuffler_empty_partitions() {
        let dir = TempStrDir::default();
        let output_dir = Path::from(dir.as_ref());
        let num_partitions = 5;

        // Only partitions 0 and 3 receive rows.
        let batch = make_batch(&[0, 3, 0, 3], &[10, 20, 30, 40], None);

        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
        let reader = run_shuffle(shuffler, vec![batch]).await;

        for (partition_id, expected) in [2usize, 0, 0, 2, 0].into_iter().enumerate() {
            assert_eq!(reader.partition_size(partition_id).unwrap(), expected);
        }

        for partition_id in [1usize, 2, 4] {
            assert!(reader.read_partition(partition_id).await.unwrap().is_none());
        }

        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
        assert_eq!(p0.num_rows(), 2);
        let p3 = collect_partition(reader.as_ref(), 3).await.unwrap();
        assert_eq!(p3.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_two_file_shuffler_loss_tracking() {
        let dir = TempStrDir::default();
        let output_dir = Path::from(dir.as_ref());
        let num_partitions = 2;

        let batches = vec![
            make_batch(&[0, 1], &[10, 20], Some(1.5)),
            make_batch(&[0, 1], &[30, 40], Some(2.5)),
            make_batch(&[0], &[50], Some(0.25)),
        ];

        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
        let reader = run_shuffle(shuffler, batches).await;

        let loss = reader.total_loss().unwrap();
        assert!((loss - 4.25).abs() < 1e-10, "expected 4.25, got {}", loss);
    }

    #[tokio::test]
    async fn test_two_file_shuffler_single_batch() {
        let dir = TempStrDir::default();
        let output_dir = Path::from(dir.as_ref());
        let num_partitions = 2;

        let batch = make_batch(&[1, 0], &[100, 200], Some(3.0));

        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
        let reader = run_shuffle(shuffler, vec![batch]).await;

        assert_eq!(reader.partition_size(0).unwrap(), 1);
        assert_eq!(reader.partition_size(1).unwrap(), 1);

        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
        assert_eq!(vals(&p0)[0], 200);

        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
        assert_eq!(vals(&p1)[0], 100);

        assert!((reader.total_loss().unwrap() - 3.0).abs() < 1e-10);
    }

    #[tokio::test]
    async fn test_two_file_shuffler_multiple_batches() {
        let dir = TempStrDir::default();
        let output_dir = Path::from(dir.as_ref());
        let num_partitions = 3;

        let batches = vec![
            make_batch(&[0, 1, 2], &[10, 20, 30], Some(1.0)),
            make_batch(&[2, 0, 1], &[40, 50, 60], Some(2.0)),
            make_batch(&[1, 2, 0], &[70, 80, 90], Some(3.0)),
        ];

        // A tiny batch size is used here instead of the default.
        let shuffler = TwoFileShuffler::new(output_dir, num_partitions).with_batch_size_bytes(16);
        let reader = run_shuffle(shuffler, batches).await;

        assert_eq!(reader.partition_size(0).unwrap(), 3);
        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
        assert_eq!(sorted_vals(&p0), vec![10, 50, 90]);

        assert_eq!(reader.partition_size(1).unwrap(), 3);
        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
        assert_eq!(sorted_vals(&p1), vec![20, 60, 70]);

        assert_eq!(reader.partition_size(2).unwrap(), 3);
        let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
        assert_eq!(sorted_vals(&p2), vec![30, 40, 80]);

        assert!((reader.total_loss().unwrap() - 6.0).abs() < 1e-10);
    }
}