Skip to main content

lance_index/vector/v3/
shuffler.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Shuffler is a component that takes a stream of record batches and shuffles them into
5//! the corresponding IVF partitions.
6
7use std::ops::Range;
8use std::sync::atomic::AtomicU64;
9use std::sync::{Arc, Mutex};
10
11use arrow::compute::concat_batches;
12use arrow::datatypes::UInt64Type;
13use arrow::{array::AsArray, compute::sort_to_indices};
14use arrow_array::{RecordBatch, UInt32Array, UInt64Array};
15use arrow_schema::{DataType, Field, Schema};
16use futures::{future::try_join_all, prelude::*};
17use lance_arrow::stream::rechunk_stream_by_size;
18use lance_arrow::{RecordBatchExt, SchemaExt};
19use lance_core::{
20    Error, Result,
21    cache::LanceCache,
22    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
23};
24use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
25use lance_encoding::version::LanceFileVersion;
26use lance_file::reader::{FileReader, FileReaderOptions};
27use lance_file::writer::{FileWriter, FileWriterOptions};
28use lance_io::{
29    ReadBatchParams,
30    object_store::ObjectStore,
31    scheduler::{ScanScheduler, SchedulerConfig},
32    stream::{RecordBatchStream, RecordBatchStreamAdapter},
33    utils::CachedFileSize,
34};
35use object_store::path::Path;
36
37use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
38
39#[async_trait::async_trait]
40/// A reader that can read the shuffled partitions.
41pub trait ShuffleReader: Send + Sync {
42    /// Read a partition by partition_id
43    /// will return Ok(None) if partition_size is 0
44    /// check reader.partition_size(partition_id) before calling this function
45    async fn read_partition(
46        &self,
47        partition_id: usize,
48    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;
49
50    /// Get the size of the partition by partition_id
51    fn partition_size(&self, partition_id: usize) -> Result<usize>;
52
53    /// Get the total loss,
54    /// if the loss is not available, return None,
55    /// in such case, the caller should sum up the losses from each batch's metadata.
56    /// Must be called after all partitions are read.
57    fn total_loss(&self) -> Option<f64>;
58}
59
60#[async_trait::async_trait]
61/// A shuffler that can shuffle the incoming stream of record batches into IVF partitions.
62/// Returns a IvfShuffleReader that can be used to read the shuffled partitions.
63pub trait Shuffler: Send + Sync {
64    /// Shuffle the incoming stream of record batches into IVF partitions.
65    /// Returns a IvfShuffleReader that can be used to read the shuffled partitions.
66    async fn shuffle(
67        &self,
68        data: Box<dyn RecordBatchStream + Unpin + 'static>,
69    ) -> Result<Box<dyn ShuffleReader>>;
70}
71
72pub struct IvfShuffler {
73    object_store: Arc<ObjectStore>,
74    output_dir: Path,
75    num_partitions: usize,
76    format_version: LanceFileVersion,
77
78    progress: Arc<dyn crate::progress::IndexBuildProgress>,
79}
80
81impl IvfShuffler {
82    pub fn new(output_dir: Path, num_partitions: usize) -> Self {
83        Self {
84            object_store: Arc::new(ObjectStore::local()),
85            output_dir,
86            num_partitions,
87            format_version: LanceFileVersion::V2_0,
88            progress: crate::progress::noop_progress(),
89        }
90    }
91
92    pub fn with_format_version(mut self, format_version: LanceFileVersion) -> Self {
93        self.format_version = format_version;
94        self
95    }
96
97    pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
98        self.progress = progress;
99        self
100    }
101}
102
103#[async_trait::async_trait]
104impl Shuffler for IvfShuffler {
105    async fn shuffle(
106        &self,
107        data: Box<dyn RecordBatchStream + Unpin + 'static>,
108    ) -> Result<Box<dyn ShuffleReader>> {
109        let num_partitions = self.num_partitions;
110        let mut partition_sizes = vec![0; num_partitions];
111        let schema = data.schema().without_column(PART_ID_COLUMN);
112        let mut writers = stream::iter(0..num_partitions)
113            .map(|partition_id| {
114                let part_path = self
115                    .output_dir
116                    .clone()
117                    .join(format!("ivf_{}.lance", partition_id));
118                let spill_path = self
119                    .output_dir
120                    .clone()
121                    .join(format!("ivf_{}.spill", partition_id));
122                let object_store = self.object_store.clone();
123                let schema = schema.clone();
124                let format_version = self.format_version;
125                async move {
126                    let writer = object_store.create(&part_path).await?;
127                    let file_writer = FileWriter::try_new(
128                        writer,
129                        lance_core::datatypes::Schema::try_from(&schema)?,
130                        FileWriterOptions {
131                            format_version: Some(format_version),
132                            ..Default::default()
133                        },
134                    )?
135                    .with_page_metadata_spill(object_store.clone(), spill_path);
136                    Result::Ok(file_writer)
137                }
138            })
139            .buffered(self.object_store.io_parallelism())
140            .try_collect::<Vec<_>>()
141            .await?;
142        let mut parallel_sort_stream = data
143            .map(|batch| {
144                spawn_cpu(move || {
145                    let batch = batch?;
146
147                    let loss = batch
148                        .metadata()
149                        .get(LOSS_METADATA_KEY)
150                        .map(|s| s.parse::<f64>().unwrap_or_default())
151                        .unwrap_or_default();
152
153                    let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
154
155                    let indices = sort_to_indices(&part_ids, None, None)?;
156                    let batch = batch.take(&indices)?;
157
158                    let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
159                    let batch = batch.drop_column(PART_ID_COLUMN)?;
160
161                    let mut partition_buffers = vec![Vec::new(); num_partitions];
162
163                    let mut start = 0;
164                    while start < batch.num_rows() {
165                        let part_id: u32 = part_ids.value(start);
166                        let mut end = start + 1;
167                        while end < batch.num_rows() && part_ids.value(end) == part_id {
168                            end += 1;
169                        }
170
171                        let part_batches = &mut partition_buffers[part_id as usize];
172                        part_batches.push(batch.slice(start, end - start));
173                        start = end;
174                    }
175
176                    Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
177                })
178            })
179            .buffered(get_num_compute_intensive_cpus());
180
181        let mut total_loss = 0.0;
182        let mut num_rows = 0u64;
183        while let Some(shuffled) = parallel_sort_stream.next().await {
184            let (shuffled, loss) = shuffled?;
185            total_loss += loss;
186
187            let mut futs = Vec::new();
188            for (part_id, (writer, batches)) in writers.iter_mut().zip(shuffled.iter()).enumerate()
189            {
190                if !batches.is_empty() {
191                    let rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
192                    partition_sizes[part_id] += rows;
193                    num_rows += rows as u64;
194                    futs.push(writer.write_batches(batches.iter()));
195                }
196            }
197            try_join_all(futs).await?;
198
199            self.progress.stage_progress("shuffle", num_rows).await?;
200        }
201
202        // finish all writers
203        for writer in writers.iter_mut() {
204            writer.finish().await?;
205        }
206
207        Ok(Box::new(IvfShufflerReader::new(
208            self.object_store.clone(),
209            self.output_dir.clone(),
210            partition_sizes,
211            total_loss,
212        )))
213    }
214}
215
216pub struct IvfShufflerReader {
217    scheduler: Arc<ScanScheduler>,
218    output_dir: Path,
219    partition_sizes: Vec<usize>,
220    loss: f64,
221}
222
223impl IvfShufflerReader {
224    pub fn new(
225        object_store: Arc<ObjectStore>,
226        output_dir: Path,
227        partition_sizes: Vec<usize>,
228        loss: f64,
229    ) -> Self {
230        let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
231        let scheduler = ScanScheduler::new(object_store, scheduler_config);
232        Self {
233            scheduler,
234            output_dir,
235            partition_sizes,
236            loss,
237        }
238    }
239}
240
241#[async_trait::async_trait]
242impl ShuffleReader for IvfShufflerReader {
243    async fn read_partition(
244        &self,
245        partition_id: usize,
246    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
247        if partition_id >= self.partition_sizes.len() {
248            return Ok(None);
249        }
250
251        let partition_path = self
252            .output_dir
253            .clone()
254            .join(format!("ivf_{}.lance", partition_id));
255
256        let reader = FileReader::try_open(
257            self.scheduler
258                .open_file(&partition_path, &CachedFileSize::unknown())
259                .await?,
260            None,
261            Arc::<DecoderPlugins>::default(),
262            &LanceCache::no_cache(),
263            FileReaderOptions::default(),
264        )
265        .await?;
266        let schema: Schema = reader.schema().as_ref().into();
267        let stream = reader
268            .read_stream(
269                lance_io::ReadBatchParams::RangeFull,
270                u32::MAX,
271                16,
272                FilterExpression::no_filter(),
273            )
274            .await?;
275        Ok(Some(Box::new(RecordBatchStreamAdapter::new(
276            Arc::new(schema),
277            stream,
278        ))))
279    }
280
281    fn partition_size(&self, partition_id: usize) -> Result<usize> {
282        Ok(self.partition_sizes.get(partition_id).copied().unwrap_or(0))
283    }
284
285    fn total_loss(&self) -> Option<f64> {
286        Some(self.loss)
287    }
288}
289
290pub struct EmptyReader;
291
292#[async_trait::async_trait]
293impl ShuffleReader for EmptyReader {
294    async fn read_partition(
295        &self,
296        _partition_id: usize,
297    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
298        Ok(None)
299    }
300
301    fn partition_size(&self, _partition_id: usize) -> Result<usize> {
302        Ok(0)
303    }
304
305    fn total_loss(&self) -> Option<f64> {
306        None
307    }
308}
309
310/// Create an IVF shuffler. Uses [`TwoFileShuffler`] by default, which writes
311/// all data to just two files (data + offsets) instead of one file per partition.
312/// Set `LANCE_LEGACY_SHUFFLER=1` to fall back to [`IvfShuffler`], which opens
313/// one file per partition.
314///
315/// An optional `progress` callback can be provided to receive shuffle progress
316/// updates.
317pub fn create_ivf_shuffler(
318    output_dir: Path,
319    num_partitions: usize,
320    format_version: LanceFileVersion,
321    progress: Option<Arc<dyn crate::progress::IndexBuildProgress>>,
322) -> Box<dyn Shuffler> {
323    let use_legacy = std::env::var("LANCE_LEGACY_SHUFFLER")
324        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
325        .unwrap_or(false);
326    if use_legacy {
327        let mut shuffler =
328            IvfShuffler::new(output_dir, num_partitions).with_format_version(format_version);
329        if let Some(progress) = progress {
330            shuffler = shuffler.with_progress(progress);
331        }
332        Box::new(shuffler)
333    } else {
334        let mut shuffler = TwoFileShuffler::new(output_dir, num_partitions);
335        if let Some(progress) = progress {
336            shuffler = shuffler.with_progress(progress);
337        }
338        Box::new(shuffler)
339    }
340}
341
342const DEFAULT_SHUFFLE_BATCH_BYTES: usize = 128 * 1024 * 1024;
343
344/// Limit of how much transformed data we accumulate before spilling to disk.
345///
346/// A larger value will use more RAM but require less random access during the
347/// read phase.
348///
349/// This default is likely to be fine for most use cases.
350fn shuffle_batch_bytes() -> usize {
351    let batch_size = std::env::var("LANCE_SHUFFLE_BATCH_BYTES")
352        .ok()
353        .and_then(|s| s.parse().ok())
354        .unwrap_or(DEFAULT_SHUFFLE_BATCH_BYTES);
355    if batch_size == 0 {
356        log::warn!(
357            "LANCE_SHUFFLE_BATCH_BYTES is 0, using default of {}",
358            DEFAULT_SHUFFLE_BATCH_BYTES
359        );
360        DEFAULT_SHUFFLE_BATCH_BYTES
361    } else {
362        batch_size
363    }
364}
365
366/// A shuffler that writes all data to just two files (data + offsets) instead
367/// of one file per partition. This avoids hitting OS file descriptor limits
368/// when there are many partitions.
369///
370/// First we accumulate data in memory until we reach the batch size limit.
371/// Then we sort the data by partition ID and compute an offset per partition.
372/// Then we write the data to a data file and the offsets to an offsets file.
373///
374/// To read the data back, we read every Nth value from the offsets file to get
375/// the start and end of each partition.
376///
377/// Then we read those ranges from the data file.
378pub struct TwoFileShuffler {
379    object_store: Arc<ObjectStore>,
380    output_dir: Path,
381    num_partitions: usize,
382    batch_size_bytes: usize,
383
384    progress: Arc<dyn crate::progress::IndexBuildProgress>,
385}
386
387impl TwoFileShuffler {
388    pub fn new(output_dir: Path, num_partitions: usize) -> Self {
389        Self {
390            object_store: Arc::new(ObjectStore::local()),
391            output_dir,
392            num_partitions,
393            batch_size_bytes: shuffle_batch_bytes(),
394            progress: crate::progress::noop_progress(),
395        }
396    }
397
398    pub fn with_progress(mut self, progress: Arc<dyn crate::progress::IndexBuildProgress>) -> Self {
399        self.progress = progress;
400        self
401    }
402
403    #[cfg(test)]
404    fn with_batch_size_bytes(mut self, batch_size_bytes: usize) -> Self {
405        self.batch_size_bytes = batch_size_bytes;
406        self
407    }
408}
409
410#[async_trait::async_trait]
411impl Shuffler for TwoFileShuffler {
412    async fn shuffle(
413        &self,
414        data: Box<dyn RecordBatchStream + Unpin + 'static>,
415    ) -> Result<Box<dyn ShuffleReader>> {
416        let num_partitions = self.num_partitions;
417        let full_schema = Arc::new(data.schema().as_ref().clone());
418        // No need to write partition ids since we can infer this
419        let schema = data.schema().without_column(PART_ID_COLUMN);
420        let offsets_schema = Arc::new(Schema::new(vec![Field::new(
421            "offset",
422            DataType::UInt64,
423            false,
424        )]));
425        let batch_size_bytes = self.batch_size_bytes;
426
427        // Extract loss from batch metadata before rechunking (concat_batches drops metadata)
428        let total_loss = Arc::new(Mutex::new(0.0f64));
429        let loss_ref = total_loss.clone();
430        let loss_stream = data.map(move |result| {
431            result.inspect(|batch| {
432                let loss = batch
433                    .metadata()
434                    .get(LOSS_METADATA_KEY)
435                    .and_then(|s| s.parse::<f64>().ok())
436                    .unwrap_or(0.0);
437                *loss_ref.lock().unwrap() += loss;
438            })
439        });
440
441        // Rechunk to target batch size
442        let rechunked = rechunk_stream_by_size(
443            loss_stream,
444            full_schema,
445            batch_size_bytes,
446            batch_size_bytes * 2,
447        );
448
449        // Create data file writer
450        let data_path = self.output_dir.clone().join("shuffle_data.lance");
451        let spill_path = self.output_dir.clone().join("shuffle_data.spill");
452        let writer = self.object_store.create(&data_path).await?;
453        let mut file_writer = FileWriter::try_new(
454            writer,
455            lance_core::datatypes::Schema::try_from(&schema)?,
456            Default::default(),
457        )?
458        .with_page_metadata_spill(self.object_store.clone(), spill_path);
459
460        // Create offsets file writer
461        let offsets_path = self.output_dir.clone().join("shuffle_offsets.lance");
462        let spill_path = self.output_dir.clone().join("shuffle_offsets.spill");
463        let writer = self.object_store.create(&offsets_path).await?;
464        let mut offsets_writer = FileWriter::try_new(
465            writer,
466            lance_core::datatypes::Schema::try_from(offsets_schema.as_ref())?,
467            Default::default(),
468        )?
469        .with_page_metadata_spill(self.object_store.clone(), spill_path);
470
471        let num_batches = Arc::new(AtomicU64::new(0));
472        let num_batches_ref = num_batches.clone();
473        let mut partition_counts: Vec<u64> = vec![0; num_partitions];
474        let mut global_row_count: u64 = 0;
475        let mut rows_processed: u64 = 0;
476
477        let mut rechunked = std::pin::pin!(rechunked);
478        while let Some(batch) = rechunked.next().await {
479            num_batches_ref.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
480            let batch = batch?;
481            let np = num_partitions;
482            let num_rows = batch.num_rows() as u64;
483
484            // Sort by partition ID and compute offsets on CPU
485            let (sorted_batch, batch_offsets) = spawn_cpu(move || {
486                let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
487                let indices = sort_to_indices(part_ids, None, None)?;
488                let batch = batch.take(&indices)?;
489
490                let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
491                let batch = batch.drop_column(PART_ID_COLUMN)?;
492
493                // Count rows per partition by scanning sorted part IDs
494                let mut partition_counts = vec![0u64; np];
495                for i in 0..part_ids.len() {
496                    let pid = part_ids.value(i) as usize;
497                    if pid < np {
498                        partition_counts[pid] += 1;
499                    } else {
500                        log::warn!("Partition ID {} is out of range [0, {})", pid, np);
501                    }
502                }
503
504                // Build cumulative offsets (end positions) for this batch
505                let mut batch_offsets = Vec::with_capacity(np);
506                let mut running = 0u64;
507                for count in &partition_counts {
508                    running += count;
509                    batch_offsets.push(running);
510                }
511
512                Ok::<(RecordBatch, Vec<u64>), Error>((batch, batch_offsets))
513            })
514            .await?;
515
516            // Write sorted batch to data file
517            file_writer.write_batch(&sorted_batch).await?;
518
519            // Record offsets adjusted by global row count
520            let mut adjusted_offsets = Vec::with_capacity(batch_offsets.len());
521            let mut last_offset = 0;
522            for (idx, offset) in batch_offsets.iter().enumerate() {
523                adjusted_offsets.push(global_row_count + offset);
524                partition_counts[idx] += offset - last_offset;
525                last_offset = *offset;
526            }
527            global_row_count += sorted_batch.num_rows() as u64;
528
529            // Write offsets to offsets file
530            let offsets_batch = RecordBatch::try_new(
531                offsets_schema.clone(),
532                vec![Arc::new(UInt64Array::from(adjusted_offsets))],
533            )?;
534            offsets_writer.write_batch(&offsets_batch).await?;
535
536            rows_processed += num_rows;
537            self.progress
538                .stage_progress("shuffle", rows_processed)
539                .await?;
540        }
541
542        // Finish files
543        file_writer.finish().await?;
544        offsets_writer.finish().await?;
545
546        let num_batches = num_batches.load(std::sync::atomic::Ordering::Relaxed);
547
548        let total_loss_val = *total_loss.lock().unwrap();
549
550        TwoFileShuffleReader::try_new(
551            self.object_store.clone(),
552            self.output_dir.clone(),
553            num_partitions,
554            num_batches,
555            partition_counts,
556            total_loss_val,
557        )
558        .await
559    }
560}
561
562pub struct TwoFileShuffleReader {
563    _scheduler: Arc<ScanScheduler>,
564    file_reader: FileReader,
565    offsets_reader: FileReader,
566    num_partitions: usize,
567    num_batches: u64,
568    partition_counts: Vec<u64>,
569    total_loss: f64,
570}
571
572impl TwoFileShuffleReader {
573    async fn try_new(
574        object_store: Arc<ObjectStore>,
575        output_dir: Path,
576        num_partitions: usize,
577        num_batches: u64,
578        partition_counts: Vec<u64>,
579        total_loss: f64,
580    ) -> Result<Box<dyn ShuffleReader>> {
581        if num_batches == 0 {
582            return Ok(Box::new(EmptyReader));
583        }
584
585        let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
586        let scheduler = ScanScheduler::new(object_store, scheduler_config);
587
588        let data_path = output_dir.clone().join("shuffle_data.lance");
589        let file_reader = FileReader::try_open(
590            scheduler
591                .open_file(&data_path, &CachedFileSize::unknown())
592                .await?,
593            None,
594            Arc::<DecoderPlugins>::default(),
595            &LanceCache::no_cache(),
596            FileReaderOptions::default(),
597        )
598        .await?;
599
600        let offsets_path = output_dir.clone().join("shuffle_offsets.lance");
601        let offsets_reader = FileReader::try_open(
602            scheduler
603                .open_file(&offsets_path, &CachedFileSize::unknown())
604                .await?,
605            None,
606            Arc::<DecoderPlugins>::default(),
607            &LanceCache::no_cache(),
608            FileReaderOptions::default(),
609        )
610        .await?;
611
612        Ok(Box::new(Self {
613            _scheduler: scheduler,
614            file_reader,
615            offsets_reader,
616            num_partitions,
617            num_batches,
618            partition_counts,
619            total_loss,
620        }))
621    }
622
623    async fn partition_ranges(&self, partition_id: usize) -> Result<Vec<Range<u64>>> {
624        let mut positions = Vec::with_capacity(self.num_batches as usize * 2);
625        for batch_idx in 0..self.num_batches {
626            let end_pos = u32::try_from(batch_idx as usize * self.num_partitions + partition_id)
627                .map_err(|_| Error::invalid_input("There are more than 2^32 partition offsets in the spill file.  Need to support 64-bit take"))?;
628            if end_pos != 0 {
629                positions.push(end_pos - 1);
630            }
631            positions.push(end_pos);
632        }
633        let positions = UInt32Array::from(positions);
634        let num_positions = positions.len() as u32;
635        let offsets_stream = self
636            .offsets_reader
637            .read_stream(
638                ReadBatchParams::Indices(positions),
639                num_positions,
640                1,
641                FilterExpression::no_filter(),
642            )
643            .await?;
644        let schema = offsets_stream.schema().clone();
645        let offsets = offsets_stream.try_collect::<Vec<_>>().await?;
646        let offsets = if offsets.is_empty() {
647            // We should not hit this path if there is no batches
648            unreachable!()
649        } else if offsets.len() == 1 {
650            offsets.into_iter().next().unwrap()
651        } else {
652            concat_batches(&schema, &offsets)?
653        };
654
655        let offsets = offsets.column(0).as_primitive::<UInt64Type>();
656        let mut offsets_iter = offsets.values().iter().copied();
657
658        let mut ranges = Vec::with_capacity(self.num_batches as usize);
659        for batch_idx in 0..self.num_batches {
660            if batch_idx == 0 && partition_id == 0 {
661                // Implicit 0 for start-of-file
662                ranges.push(0..offsets_iter.next().unwrap());
663            } else {
664                ranges.push(offsets_iter.next().unwrap()..offsets_iter.next().unwrap());
665            }
666        }
667        Ok(ranges)
668    }
669}
670
671#[async_trait::async_trait]
672impl ShuffleReader for TwoFileShuffleReader {
673    async fn read_partition(
674        &self,
675        partition_id: usize,
676    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
677        if partition_id >= self.num_partitions {
678            return Ok(None);
679        }
680        if self.partition_counts[partition_id] == 0 {
681            return Ok(None);
682        }
683
684        let ranges = self.partition_ranges(partition_id).await?;
685        if ranges.is_empty() {
686            return Ok(None);
687        }
688
689        let schema: Schema = self.file_reader.schema().as_ref().into();
690        let stream = self
691            .file_reader
692            .read_stream(
693                ReadBatchParams::Ranges(ranges.into()),
694                u32::MAX,
695                16,
696                FilterExpression::no_filter(),
697            )
698            .await?;
699        Ok(Some(Box::new(RecordBatchStreamAdapter::new(
700            Arc::new(schema),
701            stream,
702        ))))
703    }
704
705    fn partition_size(&self, partition_id: usize) -> Result<usize> {
706        Ok(self
707            .partition_counts
708            .get(partition_id)
709            .copied()
710            .unwrap_or(0) as usize)
711    }
712
713    fn total_loss(&self) -> Option<f64> {
714        Some(self.total_loss)
715    }
716}
717
718#[cfg(test)]
719mod tests {
720    use super::*;
721
722    use arrow_array::{Int32Array, RecordBatch, UInt32Array};
723    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
724    use futures::stream;
725    use lance_arrow::RecordBatchExt;
726    use lance_core::utils::tempfile::TempStrDir;
727    use lance_io::stream::RecordBatchStreamAdapter;
728
729    use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
730
731    /// Create a test batch with partition IDs, an int column, and optional loss metadata.
732    fn make_batch(part_ids: &[u32], values: &[i32], loss: Option<f64>) -> RecordBatch {
733        let schema = Arc::new(ArrowSchema::new(vec![
734            Field::new(PART_ID_COLUMN, DataType::UInt32, false),
735            Field::new("val", DataType::Int32, false),
736        ]));
737        let batch = RecordBatch::try_new(
738            schema,
739            vec![
740                Arc::new(UInt32Array::from(part_ids.to_vec())),
741                Arc::new(Int32Array::from(values.to_vec())),
742            ],
743        )
744        .unwrap();
745        if let Some(loss_val) = loss {
746            batch
747                .add_metadata(LOSS_METADATA_KEY.to_owned(), loss_val.to_string())
748                .unwrap()
749        } else {
750            batch
751        }
752    }
753
754    fn batches_to_stream(
755        batches: Vec<RecordBatch>,
756    ) -> Box<dyn RecordBatchStream + Unpin + 'static> {
757        let schema = batches[0].schema();
758        let stream = stream::iter(batches.into_iter().map(Ok));
759        Box::new(RecordBatchStreamAdapter::new(schema, stream))
760    }
761
762    /// Collect all rows from a partition into a single RecordBatch.
763    async fn collect_partition(
764        reader: &dyn ShuffleReader,
765        partition_id: usize,
766    ) -> Option<RecordBatch> {
767        let stream = reader.read_partition(partition_id).await.unwrap()?;
768        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
769        if batches.is_empty() {
770            return None;
771        }
772        Some(arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap())
773    }
774
775    #[tokio::test]
776    async fn test_two_file_shuffler_round_trip() {
777        let dir = TempStrDir::default();
778        let output_dir = Path::from(dir.as_ref());
779        let num_partitions = 3;
780
781        // Partition 0: rows with values 10, 40
782        // Partition 1: rows with values 20, 50
783        // Partition 2: rows with values 30
784        let batch = make_batch(&[0, 1, 2, 0, 1], &[10, 20, 30, 40, 50], None);
785
786        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
787        let stream = batches_to_stream(vec![batch]);
788        let reader = shuffler.shuffle(stream).await.unwrap();
789
790        // Verify partition sizes
791        assert_eq!(reader.partition_size(0).unwrap(), 2);
792        assert_eq!(reader.partition_size(1).unwrap(), 2);
793        assert_eq!(reader.partition_size(2).unwrap(), 1);
794
795        // Verify partition 0 data
796        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
797        let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
798        let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
799        v.sort();
800        assert_eq!(v, vec![10, 40]);
801
802        // Verify partition 1 data
803        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
804        let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
805        let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
806        v.sort();
807        assert_eq!(v, vec![20, 50]);
808
809        // Verify partition 2 data
810        let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
811        let vals: &Int32Array = p2.column_by_name("val").unwrap().as_primitive();
812        let v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
813        assert_eq!(v, vec![30]);
814
815        // Out of range partition returns None
816        assert!(reader.read_partition(3).await.unwrap().is_none());
817    }
818
819    #[tokio::test]
820    async fn test_two_file_shuffler_empty_partitions() {
821        let dir = TempStrDir::default();
822        let output_dir = Path::from(dir.as_ref());
823        let num_partitions = 5;
824
825        // Only use partitions 0 and 3, leaving 1, 2, 4 empty
826        let batch = make_batch(&[0, 3, 0, 3], &[10, 20, 30, 40], None);
827
828        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
829        let stream = batches_to_stream(vec![batch]);
830        let reader = shuffler.shuffle(stream).await.unwrap();
831
832        assert_eq!(reader.partition_size(0).unwrap(), 2);
833        assert_eq!(reader.partition_size(1).unwrap(), 0);
834        assert_eq!(reader.partition_size(2).unwrap(), 0);
835        assert_eq!(reader.partition_size(3).unwrap(), 2);
836        assert_eq!(reader.partition_size(4).unwrap(), 0);
837
838        assert!(reader.read_partition(1).await.unwrap().is_none());
839        assert!(reader.read_partition(2).await.unwrap().is_none());
840        assert!(reader.read_partition(4).await.unwrap().is_none());
841
842        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
843        assert_eq!(p0.num_rows(), 2);
844        let p3 = collect_partition(reader.as_ref(), 3).await.unwrap();
845        assert_eq!(p3.num_rows(), 2);
846    }
847
848    #[tokio::test]
849    async fn test_two_file_shuffler_loss_tracking() {
850        let dir = TempStrDir::default();
851        let output_dir = Path::from(dir.as_ref());
852        let num_partitions = 2;
853
854        let batch1 = make_batch(&[0, 1], &[10, 20], Some(1.5));
855        let batch2 = make_batch(&[0, 1], &[30, 40], Some(2.5));
856        let batch3 = make_batch(&[0], &[50], Some(0.25));
857
858        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
859        let stream = batches_to_stream(vec![batch1, batch2, batch3]);
860        let reader = shuffler.shuffle(stream).await.unwrap();
861
862        let loss = reader.total_loss().unwrap();
863        assert!((loss - 4.25).abs() < 1e-10, "expected 4.25, got {}", loss);
864    }
865
866    #[tokio::test]
867    async fn test_two_file_shuffler_single_batch() {
868        let dir = TempStrDir::default();
869        let output_dir = Path::from(dir.as_ref());
870        let num_partitions = 2;
871
872        let batch = make_batch(&[1, 0], &[100, 200], Some(3.0));
873
874        let shuffler = TwoFileShuffler::new(output_dir, num_partitions);
875        let stream = batches_to_stream(vec![batch]);
876        let reader = shuffler.shuffle(stream).await.unwrap();
877
878        assert_eq!(reader.partition_size(0).unwrap(), 1);
879        assert_eq!(reader.partition_size(1).unwrap(), 1);
880
881        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
882        let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
883        assert_eq!(vals.value(0), 200);
884
885        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
886        let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
887        assert_eq!(vals.value(0), 100);
888
889        assert!((reader.total_loss().unwrap() - 3.0).abs() < 1e-10);
890    }
891
892    #[tokio::test]
893    async fn test_two_file_shuffler_multiple_batches() {
894        let dir = TempStrDir::default();
895        let output_dir = Path::from(dir.as_ref());
896        let num_partitions = 3;
897
898        // Use a very small batch size to force multiple write batches
899        // Each i32 is 4 bytes, each u32 is 4 bytes, so ~8 bytes/row.
900        // With a small batch_size_bytes, we get multiple rechunked batches.
901        let batch1 = make_batch(&[0, 1, 2], &[10, 20, 30], Some(1.0));
902        let batch2 = make_batch(&[2, 0, 1], &[40, 50, 60], Some(2.0));
903        let batch3 = make_batch(&[1, 2, 0], &[70, 80, 90], Some(3.0));
904
905        let shuffler = TwoFileShuffler::new(output_dir, num_partitions)
906            // Set very small batch size to force multiple batches
907            .with_batch_size_bytes(16);
908        let stream = batches_to_stream(vec![batch1, batch2, batch3]);
909        let reader = shuffler.shuffle(stream).await.unwrap();
910
911        // Partition 0 should have values: 10, 50, 90
912        assert_eq!(reader.partition_size(0).unwrap(), 3);
913        let p0 = collect_partition(reader.as_ref(), 0).await.unwrap();
914        let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive();
915        let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
916        v.sort();
917        assert_eq!(v, vec![10, 50, 90]);
918
919        // Partition 1 should have values: 20, 60, 70
920        assert_eq!(reader.partition_size(1).unwrap(), 3);
921        let p1 = collect_partition(reader.as_ref(), 1).await.unwrap();
922        let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive();
923        let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
924        v.sort();
925        assert_eq!(v, vec![20, 60, 70]);
926
927        // Partition 2 should have values: 30, 40, 80
928        assert_eq!(reader.partition_size(2).unwrap(), 3);
929        let p2 = collect_partition(reader.as_ref(), 2).await.unwrap();
930        let vals: &Int32Array = p2.column_by_name("val").unwrap().as_primitive();
931        let mut v: Vec<i32> = vals.iter().map(|x| x.unwrap()).collect();
932        v.sort();
933        assert_eq!(v, vec![30, 40, 80]);
934
935        assert!((reader.total_loss().unwrap() - 6.0).abs() < 1e-10);
936    }
937}