lance_index/vector/ivf/shuffler.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Disk-based shuffle of a stream of [RecordBatch] into each IVF partition.
//!
//! 1. write the entire stream to a file
//! 2. count the number of rows in each partition
//! 3. read the data back into memory and shuffle into grouped vectors
//!
//! Problems for the future:
//! 1. while the group-by column will stay the same, we may want to include extra data columns in the future
//! 2. shuffling into memory is fast, but we should add a disk buffer to support bigger datasets
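//!
//! A minimal sketch of the flow above using [IvfShuffler] directly (an
//! `ignore`d example, assuming `stream` already carries the row ID and
//! partition ID columns and that `num_partitions`, `batches_per_chunk`, and
//! `concurrency` are defined by the caller):
//!
//! ```ignore
//! let mut shuffler = IvfShuffler::try_new(num_partitions, None, true, None)?;
//! // step 1: write the unsorted stream to a temporary file
//! shuffler.write_unsorted_stream(stream).await?;
//! // steps 2 and 3: count rows per partition, then sort chunks into files
//! let files = shuffler
//!     .write_partitioned_shuffles(batches_per_chunk, concurrency)
//!     .await?;
//! let streams =
//!     IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, files).await?;
//! ```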

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{
    ArrayBuilder, FixedSizeListBuilder, StructBuilder, UInt32Builder, UInt64Builder, UInt8Builder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::compute::sort_to_indices;
use arrow::datatypes::UInt32Type;
use arrow_array::{cast::AsArray, types::UInt64Type, Array, RecordBatch, UInt32Array};
use arrow_array::{FixedSizeListArray, UInt8Array};
use arrow_array::{ListArray, StructArray, UInt64Array};
use arrow_schema::{DataType, Field, Fields};
use futures::stream::repeat_with;
use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
use lance_arrow::RecordBatchExt;
use lance_core::cache::LanceCache;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{datatypes::Schema, Error, Result, ROW_ID};
use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
use lance_file::reader::FileReader;
use lance_file::v2::reader::{FileReader as Lancev2FileReader, FileReaderOptions};
use lance_file::v2::writer::FileWriterOptions;
use lance_file::writer::FileWriter;
use lance_io::object_store::ObjectStore;
use lance_io::scheduler::{ScanScheduler, SchedulerConfig};
use lance_io::stream::RecordBatchStream;
use lance_io::utils::CachedFileSize;
use lance_io::ReadBatchParams;
use lance_table::format::SelfDescribingFileReader;
use lance_table::io::manifest::ManifestDescribing;
use log::info;
use object_store::path::Path;
use snafu::location;
use tempfile::TempDir;

use crate::vector::ivf::IvfTransformer;
use crate::vector::transform::Transformer;
use crate::vector::PART_ID_COLUMN;

const UNSORTED_BUFFER: &str = "unsorted.lance";
const SHUFFLE_BATCH_SIZE: usize = 1024;

fn get_temp_dir() -> Result<Path> {
    // Note: using keep() here means we will not delete this TempDir automatically
    let dir = TempDir::new()?.keep();
    let tmp_dir_path = Path::from_filesystem_path(dir).map_err(|e| Error::IO {
        source: Box::new(e),
        location: location!(),
    })?;
    Ok(tmp_dir_path)
}

/// A builder for a partition of data
///
/// After we sort a batch of data into partitions we append those slices into this builder.
///
/// The builder is pre-allocated and so this extend operation should only be a memcpy
#[derive(Debug)]
struct PartitionBuilder {
    builder: StructBuilder,
}

// Fork of arrow_array::builder::make_builder that handles FixedSizeList >_<
//
// Not really suitable for upstreaming because FixedSizeListBuilder<Box<dyn ArrayBuilder>> is
// awkward and the entire make_builder function needs some overhaul (dyn ArrayBuilder should have
// an extend(array: &dyn Array) method).
fn make_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
    if let DataType::FixedSizeList(inner, dim) = datatype {
        let inner_builder =
            arrow_array::builder::make_builder(inner.data_type(), capacity * (*dim) as usize);
        Box::new(FixedSizeListBuilder::new(inner_builder, *dim))
    } else {
        arrow_array::builder::make_builder(datatype, capacity)
    }
}

// Fork of StructBuilder::from_fields that handles FixedSizeList >_<
fn from_fields(fields: impl Into<Fields>, capacity: usize) -> StructBuilder {
    let fields = fields.into();
    let mut builders = Vec::with_capacity(fields.len());
    for field in &fields {
        builders.push(make_builder(field.data_type(), capacity));
    }
    StructBuilder::new(fields, builders)
}

impl PartitionBuilder {
    fn new(schema: &arrow_schema::Schema, initial_capacity: usize) -> Self {
        let builder = from_fields(schema.fields.clone(), initial_capacity);
        Self { builder }
    }

    fn extend(&mut self, batch: &RecordBatch) {
        for _ in 0..batch.num_rows() {
            self.builder.append(true);
        }
        let schema = batch.schema_ref();
        for (field_idx, (col, field)) in batch.columns().iter().zip(schema.fields()).enumerate() {
            match field.data_type() {
                DataType::UInt32 => {
                    let col = col.as_any().downcast_ref::<UInt32Array>().unwrap();
                    self.builder
                        .field_builder::<UInt32Builder>(field_idx)
                        .unwrap()
                        .append_slice(col.values());
                }
                DataType::UInt64 => {
                    let col = col.as_any().downcast_ref::<UInt64Array>().unwrap();
                    self.builder
                        .field_builder::<UInt64Builder>(field_idx)
                        .unwrap()
                        .append_slice(col.values());
                }
                DataType::FixedSizeList(inner, _) => {
                    let col = col.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
                    match inner.data_type() {
                        DataType::UInt8 => {
                            let values =
                                col.values().as_any().downcast_ref::<UInt8Array>().unwrap();
                            let fsl_builder = self
                                .builder
                                .field_builder::<FixedSizeListBuilder<Box<dyn ArrayBuilder>>>(
                                    field_idx,
                                )
                                .unwrap();
                            // TODO: Upstream an append_many to FSL builder
                            for _ in 0..col.len() {
                                fsl_builder.append(true);
                            }
                            fsl_builder
                                .values()
                                .as_any_mut()
                                .downcast_mut::<UInt8Builder>()
                                .unwrap()
                                .append_slice(values.values());
                        }
                        _ => panic!("Unexpected fixed size list item type in shuffled file"),
                    }
                }
                _ => panic!("Unexpected column type in shuffled file"),
            }
        }
    }

    // Convert the partition builder into a list array with 1 row
    fn finish(mut self) -> Result<ListArray> {
        let struct_array = Arc::new(self.builder.finish());

        let item_field = Arc::new(Field::new("item", struct_array.data_type().clone(), true));

        Ok(ListArray::try_new(
            item_field,
            OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![
                0,
                struct_array.len() as i32,
            ])),
            struct_array,
            None,
        )?)
    }
}

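/// Builds one [PartitionBuilder] per partition, pre-sized using the partition
/// row counts from the counting pass.
///
/// `extend` expects each incoming batch to contain rows from a single
/// partition; callers slice sorted batches by partition ID before appending.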
struct PartitionListBuilder {
    partitions: Vec<Option<PartitionBuilder>>,
    partition_sizes: Vec<u64>,
}

impl PartitionListBuilder {
    fn new(partition_sizes: Vec<u64>) -> Self {
        Self {
            partitions: Vec::default(),
            partition_sizes,
        }
    }

    fn extend(&mut self, batch: &RecordBatch) {
        if batch.num_rows() == 0 {
            return;
        }

        if self.partitions.is_empty() {
            let schema = batch.schema();
            self.partitions = Vec::from_iter(self.partition_sizes.iter().map(|part_size| {
                if *part_size == 0 {
                    None
                } else {
                    Some(PartitionBuilder::new(schema.as_ref(), *part_size as usize))
                }
            }))
        }

        let part_ids = batch[PART_ID_COLUMN].as_primitive::<UInt32Type>();

        let part_id = part_ids.value(0) as usize;

        let builder = &mut self.partitions[part_id];
        builder
            .as_mut()
            .expect("partition size was zero but received data for partition")
            .extend(batch);
    }

    fn finish(self) -> Result<Vec<ListArray>> {
        self.partitions
            .into_iter()
            .filter_map(|builder| builder.map(|builder| builder.finish()))
            .collect()
    }
}

/// Disk-based shuffle for a stream of [RecordBatch] into each IVF partition.
/// Sub-quantizer will be applied if provided.
///
/// Parameters
/// ----------
///   *data*: input data stream.
///   *ivf*: IVF transformer used to compute partition IDs.
///   *precomputed_partitions*: optional map from row ID to partition ID.
///   *num_partitions*: number of IVF partitions.
///   *shuffle_partition_batches*: number of batches in each sorted chunk.
///   *shuffle_partition_concurrency*: number of chunks to sort concurrently.
///   *precomputed_shuffle_buffers*: optional directory and file names of
///   precomputed shuffle buffers.
///
/// Returns
/// -------
///   Result<Vec<impl Stream<Item = Result<RecordBatch>>>>: a vector of streams
///   of shuffled partitioned data. Each stream corresponds to a partition and
///   is sorted within the stream. Consumers of these streams are expected to
///   merge them into a single stream with a k-way merge.
///
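/// # Example
///
/// A minimal usage sketch (an `ignore`d example, assuming `stream` is a
/// [RecordBatchStream] that already carries the row ID and vector columns,
/// and `ivf` is a fitted [IvfTransformer]):
///
/// ```ignore
/// let partition_streams = shuffle_dataset(
///     stream,
///     Arc::new(ivf),
///     /*precomputed_partitions=*/ None,
///     /*num_partitions=*/ 256,
///     /*shuffle_partition_batches=*/ 10240,
///     /*shuffle_partition_concurrency=*/ 2,
///     /*precomputed_shuffle_buffers=*/ None,
/// )
/// .await?;
/// ```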
#[allow(clippy::too_many_arguments)]
pub async fn shuffle_dataset(
    data: impl RecordBatchStream + Unpin + 'static,
    ivf: Arc<IvfTransformer>,
    precomputed_partitions: Option<HashMap<u64, u32>>,
    num_partitions: u32,
    shuffle_partition_batches: usize,
    shuffle_partition_concurrency: usize,
    precomputed_shuffle_buffers: Option<(Path, Vec<String>)>,
) -> Result<Vec<impl Stream<Item = Result<RecordBatch>>>> {
    // step 1: either use precomputed shuffle files or write shuffle data to a file
    let shuffler = if let Some((path, buffers)) = precomputed_shuffle_buffers {
        info!("Precomputed shuffle files provided, skipping IVF partition calculation.");
        let mut shuffler = IvfShuffler::try_new(num_partitions, Some(path), true, None)?;
        unsafe {
            shuffler.set_unsorted_buffers(&buffers);
        }

        shuffler
    } else {
        info!(
            "Calculating IVF partitions for vectors (num_partitions={}, precomputed_partitions={})",
            num_partitions,
            precomputed_partitions.is_some()
        );
        let mut shuffler = IvfShuffler::try_new(num_partitions, None, true, None)?;

        let precomputed_partitions = precomputed_partitions.map(Arc::new);
        let stream = data
            .zip(repeat_with(move || ivf.clone()))
            .map(move |(b, ivf)| {
                // If a precomputed partition map is provided, use it
                // for fast partition assignment.
                let partition_map = precomputed_partitions
                    .as_ref()
                    .cloned()
                    .unwrap_or(Arc::new(HashMap::new()));

                tokio::task::spawn(async move {
                    let mut batch = b?;

                    if !partition_map.is_empty() {
                        let row_ids = batch.column_by_name(ROW_ID).ok_or(Error::Index {
                            message: "column does not exist".to_string(),
                            location: location!(),
                        })?;
                        let part_ids = UInt32Array::from_iter(
                            row_ids
                                .as_primitive::<UInt64Type>()
                                .values()
                                .iter()
                                .map(|row_id| partition_map.get(row_id).copied()),
                        );
                        batch = batch
                            .try_with_column(
                                Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true),
                                Arc::new(part_ids.clone()),
                            )
                            .expect("failed to add part id column");

                        if part_ids.null_count() > 0 {
                            info!(
                                "Filter out rows without valid partition IDs: null_count={}",
                                part_ids.null_count()
                            );
                            let indices = UInt32Array::from_iter(
                                part_ids
                                    .iter()
                                    .enumerate()
                                    .filter_map(|(idx, v)| v.map(|_| idx as u32)),
                            );
                            assert_eq!(indices.len(), batch.num_rows() - part_ids.null_count());
                            batch = batch.take(&indices)?;
                        }
                    }
                    ivf.transform(&batch)
                })
            })
            .buffer_unordered(get_num_compute_intensive_cpus())
            .map(|res| match res {
                Ok(Ok(batch)) => Ok(batch),
                Ok(Err(err)) => Err(Error::io(err.to_string(), location!())),
                Err(err) => Err(Error::io(err.to_string(), location!())),
            })
            .boxed();

        let start = std::time::Instant::now();
        shuffler.write_unsorted_stream(stream).await?;
        info!(
            "wrote partition assignment to unsorted tmp file in {:?}",
            start.elapsed()
        );

        shuffler
    };

    // step 2: stream in the shuffle data in chunks and write sorted chunks out
    let start = std::time::Instant::now();
    let partition_files = shuffler
        .write_partitioned_shuffles(shuffle_partition_batches, shuffle_partition_concurrency)
        .await?;
    info!("created sorted chunks in {:?}", start.elapsed());

    // step 3: load the sorted chunks; consumers are expected to merge the streams
    let start = std::time::Instant::now();
    let stream =
        IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files).await?;
    info!("merged partitioned shuffles in {:?}", start.elapsed());

    Ok(stream)
}

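/// Shuffle precomputed, unsorted partition buffers into sorted partition files.
///
/// The buffers named by `unsorted_filenames` are read from `dir_path`; the
/// number of IVF partitions is taken from the length of `ivf_centroids`.
/// Returns the names of the sorted output files.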
pub async fn shuffle_vectors(
    unsorted_filenames: Vec<String>,
    dir_path: Path,
    ivf_centroids: FixedSizeListArray,
    shuffle_output_root_filename: &str,
) -> Result<Vec<String>> {
    let num_partitions = ivf_centroids.len() as u32;
    let shuffle_partition_batches = SHUFFLE_BATCH_SIZE * 10;
    let shuffle_partition_concurrency = 2;
    let mut shuffler = IvfShuffler::try_new(
        num_partitions,
        Some(dir_path),
        false,
        Some(shuffle_output_root_filename.to_string()),
    )?;

    unsafe {
        shuffler.set_unsorted_buffers(&unsorted_filenames);
    }

    let partition_files = shuffler
        .write_partitioned_shuffles(shuffle_partition_batches, shuffle_partition_concurrency)
        .await?;

    Ok(partition_files)
}

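/// Disk-based shuffler that sorts partition-tagged [RecordBatch]es into
/// per-partition chunks, using Lance files as its on-disk buffers.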
#[derive(Clone)]
pub struct IvfShuffler {
    unsorted_buffers: Vec<String>,

    num_partitions: u32,

    output_dir: Path,

    // whether the lance file is v1 (legacy) or v2
    is_legacy: bool,

    shuffle_output_root_filename: String,
}

/// Represents a range of batches in a file that should be shuffled
struct ShuffleInput {
    // the idx of the file in IvfShuffler::unsorted_buffers
    file_idx: usize,
    // the start index of the batch in the file
    start: usize,
    // the end index of the batch in the file
    end: usize,
}

impl IvfShuffler {
    pub fn try_new(
        num_partitions: u32,
        output_dir: Option<Path>,
        is_legacy: bool,
        shuffle_output_root_filename: Option<String>,
    ) -> Result<Self> {
        let output_dir = match output_dir {
            Some(output_dir) => output_dir,
            None => get_temp_dir()?,
        };

        let shuffle_output_root_filename = match shuffle_output_root_filename {
            Some(shuffle_output_root_filename) => shuffle_output_root_filename,
            None => "sorted".to_string(),
        };

        Ok(Self {
            num_partitions,
            output_dir,
            unsorted_buffers: vec![],
            is_legacy,
            shuffle_output_root_filename,
        })
    }

    /// Set the unsorted buffers to be shuffled.
    ///
    /// # Safety
    ///
    /// The caller must ensure the buffers are valid.
    pub unsafe fn set_unsorted_buffers(&mut self, unsorted_buffers: &[impl ToString]) {
        self.unsorted_buffers = unsorted_buffers.iter().map(|x| x.to_string()).collect();
    }

    pub async fn write_unsorted_stream(
        &mut self,
        data: impl Stream<Item = Result<RecordBatch>>,
    ) -> Result<()> {
        let object_store = ObjectStore::local();
        let path = self.output_dir.child(UNSORTED_BUFFER);
        let writer = object_store.create(&path).await?;

        let mut data = Box::pin(data.peekable());
        let schema = match data.as_mut().peek().await {
            Some(Ok(batch)) => batch.schema(),
            Some(Err(err)) => {
                return Err(Error::io(err.to_string(), location!()));
            }
            None => {
                return Err(Error::io("empty stream".to_string(), location!()));
            }
        };

        // validate the schema:
        // we need to have row ID and partition ID columns
        schema
            .column_with_name(ROW_ID)
            .ok_or(Error::io("row ID column not found".to_owned(), location!()))?;
        schema.column_with_name(PART_ID_COLUMN).ok_or(Error::io(
            "partition ID column not found".to_owned(),
            location!(),
        ))?;

        info!("Writing unsorted data to disk at {}", path);
        info!("with schema: {:?}", schema);

        let mut file_writer = FileWriter::<ManifestDescribing>::with_object_writer(
            writer,
            Schema::try_from(schema.as_ref())?,
            &Default::default(),
        )?;

        let mut batches_processed = 0;
        while let Some(batch) = data.next().await {
            if batches_processed % 1000 == 0 {
                info!("Partition assignment progress {}/?", batches_processed);
            }
            batches_processed += 1;
            file_writer.write(&[batch?]).await?;
        }

        file_writer.finish().await?;

        unsafe {
            self.set_unsorted_buffers(&[UNSORTED_BUFFER]);
        }

        Ok(())
    }

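    /// Count the number of batches in each unsorted buffer file. For v2 files
    /// the count is derived from the row count, assuming `SHUFFLE_BATCH_SIZE`
    /// rows per batch.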
    async fn total_batches(&self) -> Result<Vec<usize>> {
        let mut total_batches = vec![];
        for buffer in &self.unsorted_buffers {
            let object_store = ObjectStore::local();
            let path = self.output_dir.child(buffer.as_str());

            if self.is_legacy {
                let reader = FileReader::try_new_self_described(&object_store, &path, None).await?;
                total_batches.push(reader.num_batches());
            } else {
                let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
                let scheduler = ScanScheduler::new(object_store.into(), scheduler_config);
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let cache = LanceCache::with_capacity(128 * 1024 * 1024);

                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &cache,
                    FileReaderOptions::default(),
                )
                .await?;
                let num_batches = reader.metadata().num_rows / (SHUFFLE_BATCH_SIZE as u64);
                total_batches.push(num_batches as usize);
            }
        }
        Ok(total_batches)
    }

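    /// Scan the partition ID column over the given batch ranges and count how
    /// many rows fall into each partition.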
    async fn count_partition_size(&self, inputs: &[ShuffleInput]) -> Result<Vec<u64>> {
        let object_store = ObjectStore::local();
        let mut partition_sizes = vec![0; self.num_partitions as usize];
        let scheduler = ScanScheduler::new(
            Arc::new(object_store.clone()),
            SchedulerConfig::max_bandwidth(&object_store),
        );

        for &ShuffleInput {
            file_idx,
            start,
            end,
        } in inputs
        {
            let file_name = &self.unsorted_buffers[file_idx];
            let path = self.output_dir.child(file_name.as_str());

            if self.is_legacy {
                let reader = FileReader::try_new_self_described(&object_store, &path, None).await?;
                let lance_schema = reader
                    .schema()
                    .project(&[PART_ID_COLUMN])
                    .expect("part id should exist");

                let mut stream = stream::iter(start..end)
                    .map(|i| reader.read_batch(i as i32, .., &lance_schema))
                    .buffer_unordered(16);

                while let Some(batch) = stream.next().await {
                    let batch = batch?;
                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();
                    part_ids.values().iter().for_each(|part_id| {
                        partition_sizes[*part_id as usize] += 1;
                    });
                }
            } else {
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &LanceCache::no_cache(),
                    FileReaderOptions::default(),
                )
                .await?;
                let mut stream = reader
                    .read_stream(
                        lance_io::ReadBatchParams::Range(
                            (start * SHUFFLE_BATCH_SIZE)..(end * SHUFFLE_BATCH_SIZE),
                        ),
                        SHUFFLE_BATCH_SIZE as u32,
                        16,
                        FilterExpression::no_filter(),
                    )
                    .unwrap();

                while let Some(batch) = stream.next().await {
                    let batch = batch?;
                    let part_ids: &UInt32Array = batch
                        .column_by_name(PART_ID_COLUMN)
                        .expect("Partition ID column not found")
                        .as_primitive();
                    part_ids.values().iter().for_each(|part_id| {
                        partition_sizes[*part_id as usize] += 1;
                    });
                }
            }
        }

        Ok(partition_sizes)
    }

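    /// Read the given batch ranges back into memory, sort each batch by
    /// partition ID, and append the per-partition slices into pre-allocated
    /// builders sized by `partition_size`.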
    async fn shuffle_to_partitions(
        &self,
        inputs: &[ShuffleInput],
        partition_size: Vec<u64>,
        num_batches_to_sort: usize,
    ) -> Result<Vec<ListArray>> {
        info!("Shuffling into memory");

        let mut num_processed = 0;
        let mut partitions_builder = PartitionListBuilder::new(partition_size);

        for &ShuffleInput {
            file_idx,
            start,
            end,
        } in inputs
        {
            let object_store = ObjectStore::local();
            let file_name = &self.unsorted_buffers[file_idx];
            let path = self.output_dir.child(file_name.as_str());
            let mut _reader_handle = None;

            let mut stream = if self.is_legacy {
                _reader_handle =
                    Some(FileReader::try_new_self_described(&object_store, &path, None).await?);

                stream::iter(start..end)
                    .map(|i| {
                        let reader = _reader_handle.as_ref().unwrap();
                        reader.read_batch(i as i32, ReadBatchParams::RangeFull, reader.schema())
                    })
                    .buffered(16)
                    .boxed()
            } else {
                let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
                let scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
                let file = scheduler
                    .open_file(&path, &CachedFileSize::unknown())
                    .await?;
                let reader = Lancev2FileReader::try_open(
                    file,
                    None,
                    Default::default(),
                    &LanceCache::no_cache(),
                    FileReaderOptions::default(),
                )
                .await?;
                reader
                    .read_stream(
                        lance_io::ReadBatchParams::Range(
                            (start * SHUFFLE_BATCH_SIZE)..(end * SHUFFLE_BATCH_SIZE),
                        ),
                        SHUFFLE_BATCH_SIZE as u32,
                        16,
                        FilterExpression::no_filter(),
                    )?
                    .boxed()
            };

            while let Some(batch) = stream.next().await {
                if num_processed % 100 == 0 {
                    info!("Shuffle Progress {}/{}", num_processed, num_batches_to_sort);
                }
                num_processed += 1;

                let batch = batch?;

                if batch.num_rows() == 0 {
                    continue;
                }

                let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
                let indices = sort_to_indices(&part_ids, None, None)?;
                let batch = batch.take(&indices)?;

                let sorted_part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();

                let mut start = 0;
                let mut prev_id = sorted_part_ids.value(0);
                for (idx, part_id) in sorted_part_ids.values().iter().enumerate() {
                    if *part_id != prev_id {
                        partitions_builder.extend(&batch.slice(start, idx - start));
                        start = idx;
                        prev_id = *part_id;
                    }
                }
                partitions_builder.extend(&batch.slice(start, sorted_part_ids.len() - start));
            }
        }

        partitions_builder.finish()
    }

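    /// Sort the unsorted buffers into chunks of `batches_per_partition`
    /// batches, processing up to `concurrent_jobs` chunks at a time. Each
    /// sorted chunk is written to its own Lance v2 file; the chunk file names
    /// are returned.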
    pub async fn write_partitioned_shuffles(
        &self,
        batches_per_partition: usize,
        concurrent_jobs: usize,
    ) -> Result<Vec<String>> {
        let num_batches = self.total_batches().await?;
        let total_batches = num_batches.iter().sum();
        info!(
            "Sorting unsorted data into sorted chunks (batches_per_chunk={} concurrent_jobs={})",
            batches_per_partition, concurrent_jobs
        );
        stream::iter((0..total_batches).step_by(batches_per_partition))
            .zip(stream::repeat(num_batches))
            .map(|(i, num_batches)| {
                let this = self.clone();
                tokio::spawn(async move {
                    // first, calculate which files and ranges need to be processed
                    let start = i;
                    let end = std::cmp::min(i + batches_per_partition, total_batches);
                    let num_batches_to_sort = end - start;
                    let mut input = vec![];

                    let mut cumulative_size = 0;
                    for (file_idx, partition_size) in num_batches.iter().enumerate() {
                        let cur_start = cumulative_size;
                        let cur_end = cumulative_size + partition_size;

                        cumulative_size += partition_size;

                        // skip files that don't overlap the [start, end) batch range
                        let should_include_file = start < cur_end && end > cur_start;
                        if !should_include_file {
                            continue;
                        }

                        let local_start = start.saturating_sub(cur_start);
                        let local_end = std::cmp::min(end - cur_start, *partition_size);

                        input.push(ShuffleInput {
                            file_idx,
                            start: local_start,
                            end: local_end,
                        });
                    }

                    // second, count the number of rows in each partition
                    let size_counts = this.count_partition_size(&input).await?;

                    // third, shuffle the data into each partition
                    let shuffled = this
                        .shuffle_to_partitions(&input, size_counts, num_batches_to_sort)
                        .await?;

                    // finally, write the shuffled data to disk
                    let object_store = ObjectStore::local();
                    let output_file = format!("{}_{}.lance", this.shuffle_output_root_filename, i);
                    let path = this.output_dir.child(output_file.clone());
                    let writer = object_store.create(&path).await?;

                    info!(
                        "Chunk loaded into memory and sorted, writing to disk at {}",
                        path
                    );

                    let sorted_file_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
                        "partitions",
                        shuffled.first().unwrap().data_type().clone(),
                        true,
                    )]));
                    let lance_schema = Schema::try_from(sorted_file_schema.as_ref())?;
                    let mut file_writer = lance_file::v2::writer::FileWriter::try_new(
                        writer,
                        lance_schema,
                        FileWriterOptions::default(),
                    )?;

                    for (idx, partition) in shuffled.into_iter().enumerate() {
                        if idx % 1000 == 0 {
                            info!("Writing partition {}/{}", idx, this.num_partitions);
                        }
                        let batch = RecordBatch::try_new(
                            sorted_file_schema.clone(),
                            vec![Arc::new(partition)],
                        )?;
                        file_writer.write_batch(&batch).await?;
                    }

                    file_writer.finish().await?;

                    Ok(output_file) as Result<String>
                })
                .map(|join_res| join_res.unwrap())
            })
            .buffered(concurrent_jobs)
            .try_collect()
            .await
    }

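    /// Open the sorted chunk files under `basedir` and return one stream per
    /// file. Each stream unwraps the list-of-struct rows written by
    /// [Self::write_partitioned_shuffles], yielding one [RecordBatch] per
    /// partition.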
    pub async fn load_partitioned_shuffles(
        basedir: &Path,
        files: Vec<String>,
    ) -> Result<Vec<impl Stream<Item = Result<RecordBatch>>>> {
        // impl RecordBatchStream
        let mut streams = vec![];

        for file in files {
            let object_store = Arc::new(ObjectStore::local());
            let path = basedir.child(file);
            let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
            let scan_scheduler = ScanScheduler::new(object_store, scheduler_config);
            let file_scheduler = scan_scheduler
                .open_file(&path, &CachedFileSize::unknown())
                .await?;
            let reader = lance_file::v2::reader::FileReader::try_open(
                file_scheduler,
                None,
                Arc::<DecoderPlugins>::default(),
                &LanceCache::no_cache(),
                FileReaderOptions::default(),
            )
            .await?;
            let stream = reader
                .read_stream(
                    ReadBatchParams::RangeFull,
                    /*batch_size=*/ 1,
                    /*batch_readahead=*/ 32,
                    FilterExpression::no_filter(),
                )?
                .and_then(|batch| {
                    let list_array = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<ListArray>()
                        .expect("ListArray expected");
                    let struct_array = list_array
                        .values()
                        .as_any()
                        .downcast_ref::<StructArray>()
                        .expect("StructArray expected")
                        .clone();
                    let batch: RecordBatch = struct_array.into();
                    std::future::ready(Ok(batch))
                });

            streams.push(stream);
        }

        Ok(streams)
    }
}

#[cfg(test)]
mod test {
    use arrow_array::{
        types::{UInt32Type, UInt8Type},
        FixedSizeListArray, UInt64Array, UInt8Array,
    };
    use arrow_schema::DataType;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use lance_io::stream::RecordBatchStreamAdapter;
    use rand::RngCore;

    use crate::vector::PQ_CODE_COLUMN;

    use super::*;

    fn make_schema(pq_dim: u32) -> Arc<arrow_schema::Schema> {
        Arc::new(arrow_schema::Schema::new(vec![
            ROW_ID_FIELD.clone(),
            arrow_schema::Field::new(PART_ID_COLUMN, DataType::UInt32, true),
            arrow_schema::Field::new(
                PQ_CODE_COLUMN,
                DataType::FixedSizeList(
                    Arc::new(arrow_schema::Field::new("item", DataType::UInt8, true)),
                    pq_dim as i32,
                ),
                false,
            ),
        ]))
    }

    fn make_stream_and_shuffler(
        include_empty_batches: bool,
    ) -> (impl RecordBatchStream, IvfShuffler) {
        let schema = make_schema(32);

        let schema2 = schema.clone();

        let stream =
            stream::iter(0..if include_empty_batches { 101 } else { 100 }).map(move |idx| {
                if include_empty_batches && idx == 100 {
                    return Ok(RecordBatch::try_new(
                        schema2.clone(),
                        vec![
                            Arc::new(UInt64Array::from_iter_values([])),
                            Arc::new(UInt32Array::from_iter_values([])),
                            Arc::new(
                                FixedSizeListArray::try_new_from_values(
                                    Arc::new(UInt8Array::from_iter_values([])) as Arc<dyn Array>,
                                    32,
                                )
                                .unwrap(),
                            ),
                        ],
                    )
                    .unwrap());
                }
                let start_idx = idx * (SHUFFLE_BATCH_SIZE as u64);
                let end_idx = (idx + 1) * (SHUFFLE_BATCH_SIZE as u64);
                let row_ids = Arc::new(UInt64Array::from_iter(start_idx..end_idx));

                let part_id = Arc::new(UInt32Array::from_iter(
                    (start_idx..end_idx).map(|_| idx as u32),
                ));

                let values = Arc::new(UInt8Array::from_iter(
                    (0..32 * SHUFFLE_BATCH_SIZE).map(|_| idx as u8),
                ));
                let pq_codes = Arc::new(
                    FixedSizeListArray::try_new_from_values(values as Arc<dyn Array>, 32).unwrap(),
                );

                Ok(
                    RecordBatch::try_new(schema2.clone(), vec![row_ids, part_id, pq_codes])
                        .unwrap(),
                )
            });

        let stream = RecordBatchStreamAdapter::new(schema, stream);

        let shuffler = IvfShuffler::try_new(100, None, true, None).unwrap();

        (stream, shuffler)
    }

    fn check_batch(batch: RecordBatch, idx: usize, num_rows: usize) {
        let row_ids = batch
            .column_by_name(ROW_ID)
            .unwrap()
            .as_primitive::<UInt64Type>();
        let part_ids = batch
            .column_by_name(PART_ID_COLUMN)
            .unwrap()
            .as_primitive::<UInt32Type>();
        let pq_codes = batch
            .column_by_name(PQ_CODE_COLUMN)
            .unwrap()
            .as_fixed_size_list()
            .values()
            .as_primitive::<UInt8Type>();

        assert_eq!(row_ids.len(), num_rows);
        assert_eq!(part_ids.len(), num_rows);
        assert_eq!(pq_codes.len(), num_rows * 32);

        for i in 0..num_rows {
            assert_eq!(part_ids.value(i), idx as u32);
        }

        for v in pq_codes.values() {
            assert_eq!(*v, idx as u8);
        }
    }

    #[tokio::test]
    async fn test_shuffler_single_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(100, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        let mut stream = result_stream.pop().unwrap();

        while let Some(item) = stream.next().await {
            check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
            num_batches += 1;
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_single_partition_with_empty_batch() {
        let (stream, mut shuffler) = make_stream_and_shuffler(true);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(101, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        let mut stream = result_stream.pop().unwrap();

        while let Some(item) = stream.next().await {
            check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
            num_batches += 1;
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multiple_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler.write_partitioned_shuffles(1, 100).await.unwrap();

        assert_eq!(partition_files.len(), 100);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches, SHUFFLE_BATCH_SIZE);
                num_batches += 1;
            }
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multi_buffer_single_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);
        shuffler.write_unsorted_stream(stream).await.unwrap();

        // set the same buffer twice; we should get double the data
        unsafe { shuffler.set_unsorted_buffers(&[UNSORTED_BUFFER, UNSORTED_BUFFER]) }

        let partition_files = shuffler.write_partitioned_shuffles(200, 1).await.unwrap();

        assert_eq!(partition_files.len(), 1);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches, 2048);
                num_batches += 1;
            }
        }

        assert_eq!(num_batches, 100);
    }

    #[tokio::test]
    async fn test_shuffler_multi_buffer_multi_partition() {
        let (stream, mut shuffler) = make_stream_and_shuffler(false);
        shuffler.write_unsorted_stream(stream).await.unwrap();

        // set the same buffer twice; we should get double the data
        unsafe { shuffler.set_unsorted_buffers(&[UNSORTED_BUFFER, UNSORTED_BUFFER]) }

        let partition_files = shuffler.write_partitioned_shuffles(1, 32).await.unwrap();
        assert_eq!(partition_files.len(), 200);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while let Some(item) = stream.next().await {
                check_batch(item.unwrap(), num_batches % 100, SHUFFLE_BATCH_SIZE);
                num_batches += 1;
            }
        }

        assert_eq!(num_batches, 200);
    }

    fn make_big_stream_and_shuffler(
        num_batches: u32,
        num_partitions: u32,
        pq_dim: u32,
    ) -> (impl RecordBatchStream, IvfShuffler) {
        let schema = make_schema(pq_dim);

        let schema2 = schema.clone();

        let stream = stream::iter(0..num_batches).map(move |idx| {
            let mut rng = rand::thread_rng();
            let row_ids = Arc::new(UInt64Array::from_iter(
                (idx * 1024..(idx + 1) * 1024).map(u64::from),
            ));

            let part_id = Arc::new(UInt32Array::from_iter(
                (idx * 1024..(idx + 1) * 1024).map(|_| rng.next_u32() % num_partitions),
            ));

            let values = Arc::new(UInt8Array::from_iter((0..pq_dim * 1024).map(|_| idx as u8)));
            let pq_codes = Arc::new(
                FixedSizeListArray::try_new_from_values(values as Arc<dyn Array>, pq_dim as i32)
                    .unwrap(),
            );

            Ok(RecordBatch::try_new(schema2.clone(), vec![row_ids, part_id, pq_codes]).unwrap())
        });

        let stream = RecordBatchStreamAdapter::new(schema, stream);

        let shuffler = IvfShuffler::try_new(num_partitions, None, true, None).unwrap();

        (stream, shuffler)
    }

    // Change NUM_BATCHES to 1000 * 1024 and NUM_PARTITIONS to 35000 to test a 1B-row shuffle
    const NUM_BATCHES: u32 = 100;
    const NUM_PARTITIONS: u32 = 1000;
    const PQ_DIM: u32 = 48;
    const BATCHES_PER_PARTITION: u32 = 10200;
    const NUM_CONCURRENT_JOBS: u32 = 16;

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_big_shuffle() {
        let (stream, mut shuffler) =
            make_big_stream_and_shuffler(NUM_BATCHES, NUM_PARTITIONS, PQ_DIM);

        shuffler.write_unsorted_stream(stream).await.unwrap();
        let partition_files = shuffler
            .write_partitioned_shuffles(
                BATCHES_PER_PARTITION as usize,
                NUM_CONCURRENT_JOBS as usize,
            )
            .await
            .unwrap();

        let expected_num_part_files = NUM_BATCHES.div_ceil(BATCHES_PER_PARTITION);

        assert_eq!(partition_files.len(), expected_num_part_files as usize);

        let mut result_stream =
            IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files)
                .await
                .unwrap();

        let mut num_batches = 0;
        result_stream.reverse();

        while let Some(mut stream) = result_stream.pop() {
            while (stream.next().await).is_some() {
                num_batches += 1;
            }
        }

        assert_eq!(num_batches, NUM_PARTITIONS * expected_num_part_files);
    }
}