//! Disk-based shuffle of a stream of [RecordBatch] into IVF partitions.
//!
//! The shuffle proceeds in stages: partition IDs are attached to each batch
//! and the stream is written to an unsorted file; the rows per partition are
//! counted; the data is then read back in chunks, sorted by partition ID, and
//! written out as per-partition lists.

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{
    ArrayBuilder, FixedSizeListBuilder, StructBuilder, UInt32Builder, UInt64Builder, UInt8Builder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::compute::sort_to_indices;
use arrow::datatypes::UInt32Type;
use arrow_array::{cast::AsArray, types::UInt64Type, Array, RecordBatch, UInt32Array};
use arrow_array::{FixedSizeListArray, UInt8Array};
use arrow_array::{ListArray, StructArray, UInt64Array};
use arrow_schema::{DataType, Field, Fields};
use futures::stream::repeat_with;
use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
use lance_arrow::RecordBatchExt;
use lance_core::cache::LanceCache;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{datatypes::Schema, Error, Result, ROW_ID};
use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
use lance_file::reader::FileReader;
use lance_file::v2::reader::{FileReader as Lancev2FileReader, FileReaderOptions};
use lance_file::v2::writer::FileWriterOptions;
use lance_file::writer::FileWriter;
use lance_io::object_store::ObjectStore;
use lance_io::scheduler::{ScanScheduler, SchedulerConfig};
use lance_io::stream::RecordBatchStream;
use lance_io::utils::CachedFileSize;
use lance_io::ReadBatchParams;
use lance_table::format::SelfDescribingFileReader;
use lance_table::io::manifest::ManifestDescribing;
use log::info;
use object_store::path::Path;
use snafu::location;
use tempfile::TempDir;

use crate::vector::ivf::IvfTransformer;
use crate::vector::transform::Transformer;
use crate::vector::PART_ID_COLUMN;

const UNSORTED_BUFFER: &str = "unsorted.lance";
const SHUFFLE_BATCH_SIZE: usize = 1024;

fn get_temp_dir() -> Result<Path> {
    // `keep()` persists the directory past the `TempDir` guard so the shuffle
    // files stay on disk while the shuffler is running.
    let dir = TempDir::new()?.keep();
    let tmp_dir_path = Path::from_filesystem_path(dir).map_err(|e| Error::IO {
        source: Box::new(e),
        location: location!(),
    })?;
    Ok(tmp_dir_path)
}

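/// Accumulates the rows of a single IVF partition into an in-memory
/// [`StructBuilder`].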
#[derive(Debug)]
struct PartitionBuilder {
    builder: StructBuilder,
}

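/// Like [`arrow_array::builder::make_builder`], but sizes the child builder of
/// a fixed-size list for `capacity * dim` values instead of `capacity`.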
fn make_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
    if let DataType::FixedSizeList(inner, dim) = datatype {
        let inner_builder =
            arrow_array::builder::make_builder(inner.data_type(), capacity * (*dim) as usize);
        Box::new(FixedSizeListBuilder::new(inner_builder, *dim))
    } else {
        arrow_array::builder::make_builder(datatype, capacity)
    }
}

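/// Builds a [`StructBuilder`] from `fields`, using [`make_builder`] for each
/// child column.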
fn from_fields(fields: impl Into<Fields>, capacity: usize) -> StructBuilder {
    let fields = fields.into();
    let mut builders = Vec::with_capacity(fields.len());
    for field in &fields {
        builders.push(make_builder(field.data_type(), capacity));
    }
    StructBuilder::new(fields, builders)
}

impl PartitionBuilder {
    fn new(schema: &arrow_schema::Schema, initial_capacity: usize) -> Self {
        let builder = from_fields(schema.fields.clone(), initial_capacity);
        Self { builder }
    }

    fn extend(&mut self, batch: &RecordBatch) {
        for _ in 0..batch.num_rows() {
            self.builder.append(true);
        }
        let schema = batch.schema_ref();
        for (field_idx, (col, field)) in batch.columns().iter().zip(schema.fields()).enumerate() {
            match field.data_type() {
                DataType::UInt32 => {
                    let col = col.as_any().downcast_ref::<UInt32Array>().unwrap();
                    self.builder
                        .field_builder::<UInt32Builder>(field_idx)
                        .unwrap()
                        .append_slice(col.values());
                }
                DataType::UInt64 => {
                    let col = col.as_any().downcast_ref::<UInt64Array>().unwrap();
                    self.builder
                        .field_builder::<UInt64Builder>(field_idx)
                        .unwrap()
                        .append_slice(col.values());
                }
                DataType::FixedSizeList(inner, _) => {
                    let col = col.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
                    match inner.data_type() {
                        DataType::UInt8 => {
                            let values =
                                col.values().as_any().downcast_ref::<UInt8Array>().unwrap();
                            let fsl_builder = self
                                .builder
                                .field_builder::<FixedSizeListBuilder<Box<dyn ArrayBuilder>>>(
                                    field_idx,
                                )
                                .unwrap();
                            // Append one list entry per row, then bulk-copy the
                            // child values in a single slice append.
                            for _ in 0..col.len() {
                                fsl_builder.append(true);
                            }
                            fsl_builder
                                .values()
                                .as_any_mut()
                                .downcast_mut::<UInt8Builder>()
                                .unwrap()
                                .append_slice(values.values());
                        }
                        _ => panic!("Unexpected fixed size list item type in shuffled file"),
                    }
                }
                _ => panic!("Unexpected column type in shuffled file"),
            }
        }
    }

    // Wrap all accumulated rows into a single-element ListArray.
    fn finish(mut self) -> Result<ListArray> {
        let struct_array = Arc::new(self.builder.finish());

        let item_field = Arc::new(Field::new("item", struct_array.data_type().clone(), true));

        Ok(ListArray::try_new(
            item_field,
            OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![
                0,
                struct_array.len() as i32,
            ])),
            struct_array,
            None,
        )?)
    }
}

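/// Builds one [`PartitionBuilder`] per non-empty partition, sized from the
/// precomputed per-partition row counts.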
struct PartitionListBuilder {
    partitions: Vec<Option<PartitionBuilder>>,
    partition_sizes: Vec<u64>,
}

impl PartitionListBuilder {
    fn new(partition_sizes: Vec<u64>) -> Self {
        Self {
            partitions: Vec::default(),
            partition_sizes,
        }
    }

    fn extend(&mut self, batch: &RecordBatch) {
        if batch.num_rows() == 0 {
            return;
        }

        // Builders are created lazily on the first non-empty batch, so the
        // schema can be taken from the data itself.
        if self.partitions.is_empty() {
            let schema = batch.schema();
            self.partitions = Vec::from_iter(self.partition_sizes.iter().map(|part_size| {
                if *part_size == 0 {
                    None
                } else {
                    Some(PartitionBuilder::new(schema.as_ref(), *part_size as usize))
                }
            }))
        }

        // Each incoming batch holds rows of exactly one partition; callers
        // slice their batches by partition ID before calling `extend`.
        let part_ids = batch[PART_ID_COLUMN].as_primitive::<UInt32Type>();

        let part_id = part_ids.value(0) as usize;

        let builder = &mut self.partitions[part_id];
        builder
            .as_mut()
            .expect("partition size was zero but received data for partition")
            .extend(batch);
    }

    fn finish(self) -> Result<Vec<ListArray>> {
        self.partitions
            .into_iter()
            .filter_map(|builder| builder.map(|builder| builder.finish()))
            .collect()
    }
}

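/// Shuffle a stream of vectors into IVF partitions.
///
/// When `precomputed_shuffle_buffers` is provided, the IVF transform is
/// skipped and the named buffers under the given path are used as-is.
/// Otherwise each batch gets partition IDs, either looked up in
/// `precomputed_partitions` by row ID or computed by `ivf`, and the stream is
/// first written to an unsorted temporary file.
///
/// The unsorted data is then sorted into chunks of `shuffle_partition_batches`
/// batches using `shuffle_partition_concurrency` concurrent jobs, and one
/// stream per sorted chunk is returned.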
#[allow(clippy::too_many_arguments)]
pub async fn shuffle_dataset(
    data: impl RecordBatchStream + Unpin + 'static,
    ivf: Arc<IvfTransformer>,
    precomputed_partitions: Option<HashMap<u64, u32>>,
    num_partitions: u32,
    shuffle_partition_batches: usize,
    shuffle_partition_concurrency: usize,
    precomputed_shuffle_buffers: Option<(Path, Vec<String>)>,
) -> Result<Vec<impl Stream<Item = Result<RecordBatch>>>> {
    // Either reuse precomputed shuffle buffers, or compute the partition
    // assignment and write it to an unsorted buffer first.
    let shuffler = if let Some((path, buffers)) = precomputed_shuffle_buffers {
        info!("Precomputed shuffle files provided, skip calculation of IVF partition.");
        let mut shuffler = IvfShuffler::try_new(num_partitions, Some(path), true, None)?;
        unsafe {
            shuffler.set_unsorted_buffers(&buffers);
        }

        shuffler
    } else {
        info!(
            "Calculating IVF partitions for vectors (num_partitions={}, precomputed_partitions={})",
            num_partitions,
            precomputed_partitions.is_some()
        );
        let mut shuffler = IvfShuffler::try_new(num_partitions, None, true, None)?;

        let precomputed_partitions = precomputed_partitions.map(Arc::new);
        let stream = data
            .zip(repeat_with(move || ivf.clone()))
            .map(move |(b, ivf)| {
                // Partition assignment is fanned out over tokio tasks, with at
                // most `get_num_compute_intensive_cpus()` in flight.
                let partition_map = precomputed_partitions
                    .as_ref()
                    .cloned()
                    .unwrap_or(Arc::new(HashMap::new()));

                tokio::task::spawn(async move {
                    let mut batch = b?;

                    if !partition_map.is_empty() {
                        let row_ids = batch.column_by_name(ROW_ID).ok_or(Error::Index {
                            message: "column does not exist".to_string(),
                            location: location!(),
                        })?;
                        // Rows without a precomputed partition become nulls and
                        // are filtered out below.
                        let part_ids = UInt32Array::from_iter(
                            row_ids
                                .as_primitive::<UInt64Type>()
                                .values()
                                .iter()
                                .map(|row_id| partition_map.get(row_id).copied()),
                        );
                        batch = batch
                            .try_with_column(
                                Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true),
                                Arc::new(part_ids.clone()),
                            )
                            .expect("failed to add part id column");

                        if part_ids.null_count() > 0 {
                            info!(
                                "Filter out rows without valid partition IDs: null_count={}",
                                part_ids.null_count()
                            );
                            let indices = UInt32Array::from_iter(
                                part_ids
                                    .iter()
                                    .enumerate()
                                    .filter_map(|(idx, v)| v.map(|_| idx as u32)),
                            );
                            assert_eq!(indices.len(), batch.num_rows() - part_ids.null_count());
                            batch = batch.take(&indices)?;
                        }
                    }
                    ivf.transform(&batch)
                })
            })
            .buffer_unordered(get_num_compute_intensive_cpus())
            .map(|res| match res {
                Ok(Ok(batch)) => Ok(batch),
                Ok(Err(err)) => Err(Error::io(err.to_string(), location!())),
                Err(err) => Err(Error::io(err.to_string(), location!())),
            })
            .boxed();

        let start = std::time::Instant::now();
        shuffler.write_unsorted_stream(stream).await?;
        info!(
            "wrote partition assignment to unsorted tmp file in {:?}",
            start.elapsed()
        );

        shuffler
    };

    // Sort the unsorted buffer(s) into chunked files, then open them as
    // streams for the caller to merge.
    let start = std::time::Instant::now();
    let partition_files = shuffler
        .write_partitioned_shuffles(shuffle_partition_batches, shuffle_partition_concurrency)
        .await?;
    info!("created sorted chunks in {:?}", start.elapsed());

    let start = std::time::Instant::now();
    let stream =
        IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files).await?;
    info!("merged partitioned shuffles in {:?}", start.elapsed());

    Ok(stream)
}

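/// Sort previously written unsorted buffers into partitioned files; the
/// partition count is taken from the number of IVF centroids.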
pub async fn shuffle_vectors(
    unsorted_filenames: Vec<String>,
    dir_path: Path,
    ivf_centroids: FixedSizeListArray,
    shuffle_output_root_filename: &str,
) -> Result<Vec<String>> {
    let num_partitions = ivf_centroids.len() as u32;
    let shuffle_partition_batches = SHUFFLE_BATCH_SIZE * 10;
    let shuffle_partition_concurrency = 2;
    let mut shuffler = IvfShuffler::try_new(
        num_partitions,
        Some(dir_path),
        false,
        Some(shuffle_output_root_filename.to_string()),
    )?;

    unsafe {
        shuffler.set_unsorted_buffers(&unsorted_filenames);
    }

    let partition_files = shuffler
        .write_partitioned_shuffles(shuffle_partition_batches, shuffle_partition_concurrency)
        .await?;

    Ok(partition_files)
}

#[derive(Clone)]
pub struct IvfShuffler {
    unsorted_buffers: Vec<String>,

    num_partitions: u32,

    output_dir: Path,

    is_legacy: bool,

    // Base name for sorted output files; the chunk starting at batch offset
    // `i` is written as "{name}_{i}.lance".
    shuffle_output_root_filename: String,
}

/// A contiguous range of batches to read from one unsorted buffer.
struct ShuffleInput {
    // the index of the file in `unsorted_buffers`
    file_idx: usize,
    // the first batch to read within that file
    start: usize,
    // one past the last batch to read within that file
    end: usize,
}

impl IvfShuffler {
    pub fn try_new(
        num_partitions: u32,
        output_dir: Option<Path>,
        is_legacy: bool,
        shuffle_output_root_filename: Option<String>,
    ) -> Result<Self> {
        let output_dir = match output_dir {
            Some(output_dir) => output_dir,
            None => get_temp_dir()?,
        };

        let shuffle_output_root_filename = match shuffle_output_root_filename {
            Some(shuffle_output_root_filename) => shuffle_output_root_filename,
            None => "sorted".to_string(),
        };

        Ok(Self {
            num_partitions,
            output_dir,
            unsorted_buffers: vec![],
            is_legacy,
            shuffle_output_root_filename,
        })
    }

    pub unsafe fn set_unsorted_buffers(&mut self, unsorted_buffers: &[impl ToString]) {
        // The buffers are not validated here; the caller must guarantee they
        // exist under `output_dir` and carry row-ID and partition-ID columns.
        self.unsorted_buffers = unsorted_buffers.iter().map(|x| x.to_string()).collect();
    }

    pub async fn write_unsorted_stream(
        &mut self,
        data: impl Stream<Item = Result<RecordBatch>>,
    ) -> Result<()> {
        let object_store = ObjectStore::local();
        let path = self.output_dir.child(UNSORTED_BUFFER);
        let writer = object_store.create(&path).await?;

        let mut data = Box::pin(data.peekable());
        let schema = match data.as_mut().peek().await {
            Some(Ok(batch)) => batch.schema(),
            Some(Err(err)) => {
                return Err(Error::io(err.to_string(), location!()));
            }
            None => {
                return Err(Error::io("empty stream".to_string(), location!()));
            }
        };

        // The shuffle relies on the row-ID and partition-ID columns; fail fast
        // if either is missing.
        schema
            .column_with_name(ROW_ID)
            .ok_or(Error::io("row ID column not found".to_owned(), location!()))?;
        schema.column_with_name(PART_ID_COLUMN).ok_or(Error::io(
            "partition ID column not found".to_owned(),
            location!(),
        ))?;

        info!("Writing unsorted data to disk at {}", path);
        info!("with schema: {:?}", schema);

        let mut file_writer = FileWriter::<ManifestDescribing>::with_object_writer(
            writer,
            Schema::try_from(schema.as_ref())?,
            &Default::default(),
        )?;

        let mut batches_processed = 0;
        while let Some(batch) = data.next().await {
            if batches_processed % 1000 == 0 {
                info!("Partition assignment progress {}/?", batches_processed);
            }
            batches_processed += 1;
            file_writer.write(&[batch?]).await?;
        }

        file_writer.finish().await?;

        unsafe {
            self.set_unsorted_buffers(&[UNSORTED_BUFFER]);
        }

        Ok(())
    }

    async fn total_batches(&self) -> Result<Vec<usize>> {
        let mut total_batches = vec![];
        for buffer in &self.unsorted_buffers {
            let object_store = ObjectStore::local();
            let path = self.output_dir.child(buffer.as_str());

            if self.is_legacy {
                let reader = FileReader::try_new_self_described(&object_store, &path, None).await?;
                total_batches.push(reader.num_batches());
            } else {
                let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
                let scheduler = ScanScheduler::new(object_store.into(), scheduler_config);
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let cache = LanceCache::with_capacity(128 * 1024 * 1024);

                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &cache,
                    FileReaderOptions::default(),
                )
                .await?;
                // v2 readers expose a row count rather than batch boundaries,
                // so this assumes the buffer was written in whole batches of
                // SHUFFLE_BATCH_SIZE rows.
                let num_batches = reader.metadata().num_rows / (SHUFFLE_BATCH_SIZE as u64);
                total_batches.push(num_batches as usize);
            }
        }
        Ok(total_batches)
    }

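    /// Count, for the given shuffle inputs, how many rows belong to each
    /// partition.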
    async fn count_partition_size(&self, inputs: &[ShuffleInput]) -> Result<Vec<u64>> {
        let object_store = ObjectStore::local();
        let mut partition_sizes = vec![0; self.num_partitions as usize];
        let scheduler = ScanScheduler::new(
            Arc::new(object_store.clone()),
            SchedulerConfig::max_bandwidth(&object_store),
        );

        for &ShuffleInput {
            file_idx,
            start,
            end,
        } in inputs
        {
            let file_name = &self.unsorted_buffers[file_idx];
            let path = self.output_dir.child(file_name.as_str());

            if self.is_legacy {
                let reader = FileReader::try_new_self_described(&object_store, &path, None).await?;
                let lance_schema = reader
                    .schema()
                    .project(&[PART_ID_COLUMN])
                    .expect("part id should exist");

                let mut stream = stream::iter(start..end)
                    .map(|i| reader.read_batch(i as i32, .., &lance_schema))
                    .buffer_unordered(16);

                while let Some(batch) = stream.next().await {
                    let batch = batch?;
                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();
                    part_ids.values().iter().for_each(|part_id| {
                        partition_sizes[*part_id as usize] += 1;
                    });
                }
            } else {
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &LanceCache::no_cache(),
                    FileReaderOptions::default(),
                )
                .await?;
                let mut stream = reader
                    .read_stream(
                        lance_io::ReadBatchParams::Range(
                            (start * SHUFFLE_BATCH_SIZE)..(end * SHUFFLE_BATCH_SIZE),
                        ),
                        SHUFFLE_BATCH_SIZE as u32,
                        16,
                        FilterExpression::no_filter(),
                    )
                    .unwrap();

                while let Some(batch) = stream.next().await {
                    let batch = batch?;
                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();
                    part_ids.values().iter().for_each(|part_id| {
                        partition_sizes[*part_id as usize] += 1;
                    });
                }
            }
        }

        Ok(partition_sizes)
    }

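    /// Read the given inputs into memory, sort each batch by partition ID, and
    /// group the rows into one [`ListArray`] per partition.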
    async fn shuffle_to_partitions(
        &self,
        inputs: &[ShuffleInput],
        partition_size: Vec<u64>,
        num_batches_to_sort: usize,
    ) -> Result<Vec<ListArray>> {
        info!("Shuffling into memory");

        let mut num_processed = 0;
        let mut partitions_builder = PartitionListBuilder::new(partition_size);

        for &ShuffleInput {
            file_idx,
            start,
            end,
        } in inputs
        {
            let object_store = ObjectStore::local();
            let file_name = &self.unsorted_buffers[file_idx];
            let path = self.output_dir.child(file_name.as_str());
            let mut _reader_handle = None;

            let mut stream = if self.is_legacy {
                _reader_handle =
                    Some(FileReader::try_new_self_described(&object_store, &path, None).await?);

                stream::iter(start..end)
                    .map(|i| {
                        let reader = _reader_handle.as_ref().unwrap();
                        reader.read_batch(i as i32, ReadBatchParams::RangeFull, reader.schema())
                    })
                    .buffered(16)
                    .boxed()
            } else {
                let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
                let scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &LanceCache::no_cache(),
                    FileReaderOptions::default(),
                )
                .await?;
                reader
                    .read_stream(
                        lance_io::ReadBatchParams::Range(
                            (start * SHUFFLE_BATCH_SIZE)..(end * SHUFFLE_BATCH_SIZE),
                        ),
                        SHUFFLE_BATCH_SIZE as u32,
                        16,
                        FilterExpression::no_filter(),
                    )?
                    .boxed()
            };

            while let Some(batch) = stream.next().await {
                if num_processed % 100 == 0 {
                    info!("Shuffle Progress {}/{}", num_processed, num_batches_to_sort);
                }
                num_processed += 1;

                let batch = batch?;

                if batch.num_rows() == 0 {
                    continue;
                }

                // Sort each batch by partition ID, then feed the builder one
                // contiguous single-partition slice at a time.
                let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
                let indices = sort_to_indices(&part_ids, None, None)?;
                let batch = batch.take(&indices)?;

                let sorted_part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();

                let mut start = 0;
                let mut prev_id = sorted_part_ids.value(0);
                for (idx, part_id) in sorted_part_ids.values().iter().enumerate() {
                    if *part_id != prev_id {
                        partitions_builder.extend(&batch.slice(start, idx - start));
                        start = idx;
                        prev_id = *part_id;
                    }
                }
                partitions_builder.extend(&batch.slice(start, sorted_part_ids.len() - start));
            }
        }

        partitions_builder.finish()
    }

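    /// Sort the unsorted buffers into chunks of `batches_per_partition`
    /// batches, writing each chunk to its own sorted file; returns the output
    /// file names.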
    pub async fn write_partitioned_shuffles(
        &self,
        batches_per_partition: usize,
        concurrent_jobs: usize,
    ) -> Result<Vec<String>> {
        let num_batches = self.total_batches().await?;
        let total_batches = num_batches.iter().sum();
        info!(
            "Sorting unsorted data into sorted chunks (batches_per_chunk={} concurrent_jobs={})",
            batches_per_partition, concurrent_jobs
        );
        stream::iter((0..total_batches).step_by(batches_per_partition))
            .zip(stream::repeat(num_batches))
            .map(|(i, num_batches)| {
                let this = self.clone();
                tokio::spawn(async move {
                    // Compute the global batch range this job is responsible
                    // for sorting.
                    let start = i;
                    let end = std::cmp::min(i + batches_per_partition, total_batches);
                    let num_batches_to_sort = end - start;
                    let mut input = vec![];

                    // Translate the global range into per-file ranges.
                    let mut cumulative_size = 0;
                    for (file_idx, partition_size) in num_batches.iter().enumerate() {
                        let cur_start = cumulative_size;
                        let cur_end = cumulative_size + partition_size;

                        cumulative_size += partition_size;

                        let should_include_file = start < cur_end && end > cur_start;

                        if !should_include_file {
                            continue;
                        }

                        let local_start = start.saturating_sub(cur_start);
                        let local_end = std::cmp::min(end - cur_start, *partition_size);

                        input.push(ShuffleInput {
                            file_idx,
                            start: local_start,
                            end: local_end,
                        });
                    }

                    // First pass: count rows per partition so the builders can
                    // be sized up front.
                    let size_counts = this.count_partition_size(&input).await?;

                    // Second pass: load the rows and group them by partition.
                    let shuffled = this
                        .shuffle_to_partitions(&input, size_counts, num_batches_to_sort)
                        .await?;

                    let object_store = ObjectStore::local();
                    let output_file = format!("{}_{}.lance", this.shuffle_output_root_filename, i);
                    let path = this.output_dir.child(output_file.clone());
                    let writer = object_store.create(&path).await?;

                    info!(
                        "Chunk loaded into memory and sorted, writing to disk at {}",
                        path
                    );

                    // Each partition becomes a single-row batch holding one
                    // list of rows.
                    let sorted_file_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
                        "partitions",
                        shuffled.first().unwrap().data_type().clone(),
                        true,
                    )]));
                    let lance_schema = Schema::try_from(sorted_file_schema.as_ref())?;
                    let mut file_writer = lance_file::v2::writer::FileWriter::try_new(
                        writer,
                        lance_schema,
                        FileWriterOptions::default(),
                    )?;

                    for (idx, partition) in shuffled.into_iter().enumerate() {
                        if idx % 1000 == 0 {
                            info!("Writing partition {}/{}", idx, this.num_partitions);
                        }
                        let batch = RecordBatch::try_new(
                            sorted_file_schema.clone(),
                            vec![Arc::new(partition)],
                        )?;
                        file_writer.write_batch(&batch).await?;
                    }

                    file_writer.finish().await?;

                    Ok(output_file) as Result<String>
                })
                .map(|join_res| join_res.unwrap())
            })
            .buffered(concurrent_jobs)
            .try_collect()
            .await
    }

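    /// Open the sorted shuffle files under `basedir`, returning one stream per
    /// file; each emitted batch holds the rows of a single partition.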
    pub async fn load_partitioned_shuffles(
        basedir: &Path,
        files: Vec<String>,
    ) -> Result<Vec<impl Stream<Item = Result<RecordBatch>>>> {
        let mut streams = vec![];

        for file in files {
            let object_store = Arc::new(ObjectStore::local());
            let path = basedir.child(file);
            let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
            let scan_scheduler = ScanScheduler::new(object_store, scheduler_config);
            let file_scheduler = scan_scheduler
                .open_file(&path, &CachedFileSize::unknown())
                .await?;
            let reader = lance_file::v2::reader::FileReader::try_open(
                file_scheduler,
                None,
                Arc::<DecoderPlugins>::default(),
                &LanceCache::no_cache(),
                FileReaderOptions::default(),
            )
            .await?;
            // Read one list per batch and unwrap it back into a RecordBatch of
            // the partition's rows.
            let stream = reader
                .read_stream(
                    ReadBatchParams::RangeFull,
                    1,
                    32,
                    FilterExpression::no_filter(),
                )?
                .and_then(|batch| {
                    let list_array = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<ListArray>()
                        .expect("ListArray expected");
                    let struct_array = list_array
                        .values()
                        .as_any()
                        .downcast_ref::<StructArray>()
                        .expect("StructArray expected")
                        .clone();
                    let batch: RecordBatch = struct_array.into();
                    std::future::ready(Ok(batch))
                });

            streams.push(stream);
        }

        Ok(streams)
    }
}

#[cfg(test)]
mod test {
    use arrow_array::{
        types::{UInt32Type, UInt8Type},
        FixedSizeListArray, UInt64Array, UInt8Array,
    };
    use arrow_schema::DataType;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use lance_io::stream::RecordBatchStreamAdapter;
    use rand::RngCore;

    use crate::vector::PQ_CODE_COLUMN;

    use super::*;

    fn make_schema(pq_dim: u32) -> Arc<arrow_schema::Schema> {
        Arc::new(arrow_schema::Schema::new(vec![
            ROW_ID_FIELD.clone(),
            arrow_schema::Field::new(PART_ID_COLUMN, DataType::UInt32, true),
            arrow_schema::Field::new(
                PQ_CODE_COLUMN,
                DataType::FixedSizeList(
                    Arc::new(arrow_schema::Field::new("item", DataType::UInt8, true)),
                    pq_dim as i32,
                ),
                false,
            ),
        ]))
    }

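    /// Builds a 100-batch stream in which batch `idx` belongs entirely to
    /// partition `idx`, plus an optional trailing empty batch.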
    fn make_stream_and_shuffler(
        include_empty_batches: bool,
    ) -> (impl RecordBatchStream, IvfShuffler) {
        let schema = make_schema(32);

        let schema2 = schema.clone();

        let stream =
            stream::iter(0..if include_empty_batches { 101 } else { 100 }).map(move |idx| {
                if include_empty_batches && idx == 100 {
                    return Ok(RecordBatch::try_new(
                        schema2.clone(),
                        vec![
                            Arc::new(UInt64Array::from_iter_values([])),
                            Arc::new(UInt32Array::from_iter_values([])),
                            Arc::new(
                                FixedSizeListArray::try_new_from_values(
                                    Arc::new(UInt8Array::from_iter_values([])) as Arc<dyn Array>,
                                    32,
                                )
                                .unwrap(),
                            ),
                        ],
                    )
                    .unwrap());
                }
                let start_idx = idx * (SHUFFLE_BATCH_SIZE as u64);
                let end_idx = (idx + 1) * (SHUFFLE_BATCH_SIZE as u64);
                let row_ids = Arc::new(UInt64Array::from_iter(start_idx..end_idx));

                let part_id = Arc::new(UInt32Array::from_iter(
                    (start_idx..end_idx).map(|_| idx as u32),
                ));

                let values = Arc::new(UInt8Array::from_iter(
                    (0..32 * SHUFFLE_BATCH_SIZE).map(|_| idx as u8),
                ));
                let pq_codes = Arc::new(
                    FixedSizeListArray::try_new_from_values(values as Arc<dyn Array>, 32).unwrap(),
                );

                Ok(
                    RecordBatch::try_new(schema2.clone(), vec![row_ids, part_id, pq_codes])
                        .unwrap(),
                )
            });

        let stream = RecordBatchStreamAdapter::new(schema, stream);

        let shuffler = IvfShuffler::try_new(100, None, true, None).unwrap();

        (stream, shuffler)
    }

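    /// Asserts that `batch` contains `num_rows` rows, all belonging to
    /// partition `idx` with PQ codes equal to `idx`.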
    fn check_batch(batch: RecordBatch, idx: usize, num_rows: usize) {
        let row_ids = batch
            .column_by_name(ROW_ID)
            .unwrap()
            .as_primitive::<UInt64Type>();
        let part_ids = batch
            .column_by_name(PART_ID_COLUMN)
            .unwrap()
            .as_primitive::<UInt32Type>();
        let pq_codes = batch
            .column_by_name(PQ_CODE_COLUMN)
            .unwrap()
            .as_fixed_size_list()
            .values()
            .as_primitive::<UInt8Type>();

        assert_eq!(row_ids.len(), num_rows);
        assert_eq!(part_ids.len(), num_rows);
        assert_eq!(pq_codes.len(), num_rows * 32);

        for i in 0..num_rows {
            assert_eq!(part_ids.value(i), idx as u32);
        }

        for v in pq_codes.values() {
            assert_eq!(*v, idx as u8);
        }
    }

    #[tokio::test]
    async fn test_shuffler_single_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(100, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        let mut stream = result_stream.pop().unwrap();

        while let Some(item) = stream.next().await {
            check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
            num_batches += 1;
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_single_partition_with_empty_batch() {
        let (stream, mut shuffler) = make_stream_and_shuffler(true);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(101, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        let mut stream = result_stream.pop().unwrap();

        while let Some(item) = stream.next().await {
            check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
            num_batches += 1;
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multiple_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(1, 100).await.unwrap();

        assert_eq!(partition_files.len(), 100);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
                num_batches += 1
            }
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multi_buffer_single_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);
        shuffler.write_unsorted_stream(stream).await.unwrap();

        // Registering the same buffer twice makes the shuffler read the data
        // twice, so each partition ends up with 2048 rows.
        unsafe { shuffler.set_unsorted_buffers(&[UNSORTED_BUFFER, UNSORTED_BUFFER]) }

        let partition_files = shuffler.write_partitioned_shuffles(200, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches, 2048);
                num_batches += 1
            }
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multi_buffer_multi_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);
        shuffler.write_unsorted_stream(stream).await.unwrap();

        // Reading the same buffer twice doubles the batch count, so each
        // logical batch produces two sorted chunks.
        unsafe { shuffler.set_unsorted_buffers(&[UNSORTED_BUFFER, UNSORTED_BUFFER]) }

        let partition_files = shuffler.write_partitioned_shuffles(1, 32).await.unwrap();
        assert_eq!(partition_files.len(), 200);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches % 100, SHUFFLE_BATCH_SIZE);
                num_batches += 1
            }
        }

        assert_eq!(num_batches, 200);
    }

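    /// Builds `num_batches` batches of 1024 rows with random partition
    /// assignments, for stress-testing the shuffler.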
    fn make_big_stream_and_shuffler(
        num_batches: u32,
        num_partitions: u32,
        pq_dim: u32,
    ) -> (impl RecordBatchStream, IvfShuffler) {
        let schema = make_schema(pq_dim);

        let schema2 = schema.clone();

        let stream = stream::iter(0..num_batches).map(move |idx| {
            let mut rng = rand::thread_rng();
            let row_ids = Arc::new(UInt64Array::from_iter(
                (idx * 1024..(idx + 1) * 1024).map(u64::from),
            ));

            let part_id = Arc::new(UInt32Array::from_iter(
                (idx * 1024..(idx + 1) * 1024).map(|_| rng.next_u32() % num_partitions),
            ));

            let values = Arc::new(UInt8Array::from_iter((0..pq_dim * 1024).map(|_| idx as u8)));
            let pq_codes = Arc::new(
                FixedSizeListArray::try_new_from_values(values as Arc<dyn Array>, pq_dim as i32)
                    .unwrap(),
            );

            Ok(RecordBatch::try_new(schema2.clone(), vec![row_ids, part_id, pq_codes]).unwrap())
        });

        let stream = RecordBatchStreamAdapter::new(schema, stream);

        let shuffler = IvfShuffler::try_new(num_partitions, None, true, None).unwrap();

        (stream, shuffler)
    }

    const NUM_BATCHES: u32 = 100;
    const NUM_PARTITIONS: u32 = 1000;
    const PQ_DIM: u32 = 48;
    // Larger than NUM_BATCHES, so the whole shuffle fits in one sorted chunk.
    const BATCHES_PER_PARTITION: u32 = 10200;
    const NUM_CONCURRENT_JOBS: u32 = 16;

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_big_shuffle() {
        let (stream, mut shuffler) =
            make_big_stream_and_shuffler(NUM_BATCHES, NUM_PARTITIONS, PQ_DIM);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler
            .write_partitioned_shuffles(
                BATCHES_PER_PARTITION as usize,
                NUM_CONCURRENT_JOBS as usize,
            )
            .await
            .unwrap();

        let expected_num_part_files = NUM_BATCHES.div_ceil(BATCHES_PER_PARTITION);

        assert_eq!(partition_files.len(), expected_num_part_files as usize);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while (stream.next().await).is_some() {
                num_batches += 1
            }
        }

        assert_eq!(num_batches, NUM_PARTITIONS * expected_num_part_files);
    }
}