// lance_index/vector/v3/shuffler.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Shuffler is a component that takes a stream of record batches and shuffles them into
5//! the corresponding IVF partitions.
6
7use std::sync::Arc;
8
9use arrow::{array::AsArray, compute::sort_to_indices};
10use arrow_array::{RecordBatch, UInt32Array};
11use arrow_schema::Schema;
12use future::try_join_all;
13use futures::prelude::*;
14use lance_arrow::{RecordBatchExt, SchemaExt};
15use lance_core::{
16    cache::LanceCache,
17    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
18    Error, Result,
19};
20use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
21use lance_file::reader::{FileReader, FileReaderOptions};
22use lance_file::writer::FileWriter;
23use lance_io::{
24    object_store::ObjectStore,
25    scheduler::{ScanScheduler, SchedulerConfig},
26    stream::{RecordBatchStream, RecordBatchStreamAdapter},
27    utils::CachedFileSize,
28};
29use object_store::path::Path;
30
31use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
32
#[async_trait::async_trait]
/// A reader that can read the shuffled partitions.
pub trait ShuffleReader: Send + Sync {
    /// Read a partition by `partition_id`.
    ///
    /// Will return `Ok(None)` if the partition size is 0;
    /// check `reader.partition_size(partition_id)` before calling this function.
    async fn read_partition(
        &self,
        partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;

    /// Get the size (in rows) of the partition by `partition_id`.
    fn partition_size(&self, partition_id: usize) -> Result<usize>;

    /// Get the total loss.
    ///
    /// If the loss is not available, returns `None`; in such case the caller
    /// should sum up the losses from each batch's metadata.
    /// Must be called after all partitions are read.
    fn total_loss(&self) -> Option<f64>;
}
53
#[async_trait::async_trait]
/// A shuffler that can shuffle an incoming stream of record batches into IVF partitions.
pub trait Shuffler: Send + Sync {
    /// Shuffle the incoming stream of record batches into IVF partitions.
    ///
    /// Consumes the entire input stream; returns a [`ShuffleReader`] that can
    /// be used to read back the shuffled partitions.
    async fn shuffle(
        &self,
        data: Box<dyn RecordBatchStream + Unpin + 'static>,
    ) -> Result<Box<dyn ShuffleReader>>;
}
65
/// Shuffler that splits incoming batches by IVF partition id and writes one
/// Lance file per partition (`ivf_{id}.lance`) under `output_dir`.
pub struct IvfShuffler {
    // Store used to create the partition files (a local store by default).
    object_store: Arc<ObjectStore>,
    // Directory that receives the per-partition output files.
    output_dir: Path,
    num_partitions: usize,

    // options
    // Number of input batches to accumulate before flushing buffered slices
    // to the partition writers.
    buffer_size: usize,
    // NOTE(review): never read anywhere in this file — presumably consumed
    // elsewhere (or dead); verify before relying on it.
    precomputed_shuffle_buffers: Option<Vec<String>>,
}
75
76impl IvfShuffler {
77    pub fn new(output_dir: Path, num_partitions: usize) -> Self {
78        Self {
79            object_store: Arc::new(ObjectStore::local()),
80            output_dir,
81            num_partitions,
82            buffer_size: 4096,
83            precomputed_shuffle_buffers: None,
84        }
85    }
86
87    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
88        self.buffer_size = buffer_size;
89        self
90    }
91
92    pub fn with_precomputed_shuffle_buffers(
93        mut self,
94        precomputed_shuffle_buffers: Option<Vec<String>>,
95    ) -> Self {
96        self.precomputed_shuffle_buffers = precomputed_shuffle_buffers;
97        self
98    }
99}
100
#[async_trait::async_trait]
impl Shuffler for IvfShuffler {
    async fn shuffle(
        &self,
        data: Box<dyn RecordBatchStream + Unpin + 'static>,
    ) -> Result<Box<dyn ShuffleReader>> {
        let num_partitions = self.num_partitions;
        // Rows written so far per partition; handed to the reader at the end.
        let mut partition_sizes = vec![0; num_partitions];
        // Output files store the original columns minus the partition-id column.
        let schema = data.schema().without_column(PART_ID_COLUMN);
        // Open one file writer per partition ("ivf_{id}.lance"), creating the
        // files concurrently, bounded by the object store's IO parallelism.
        let mut writers = stream::iter(0..num_partitions)
            .map(|partition_id| {
                let part_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));
                let object_store = self.object_store.clone();
                let schema = schema.clone();
                async move {
                    let writer = object_store.create(&part_path).await?;
                    FileWriter::try_new(
                        writer,
                        lance_core::datatypes::Schema::try_from(&schema)?,
                        Default::default(),
                    )
                }
            })
            .buffered(self.object_store.io_parallelism())
            .try_collect::<Vec<_>>()
            .await?;
        // For each incoming batch, on a CPU worker: sort rows by partition id,
        // then split the sorted batch into zero-copy slices, one run per
        // partition. Batches are processed in parallel, bounded by the number
        // of compute-intensive CPUs.
        let mut parallel_sort_stream = data
            .map(|batch| {
                spawn_cpu(move || {
                    let batch = batch?;

                    // Loss is carried in batch metadata; a missing or
                    // unparsable value counts as 0.0 (best effort).
                    let loss = batch
                        .metadata()
                        .get(LOSS_METADATA_KEY)
                        .map(|s| s.parse::<f64>().unwrap_or_default())
                        .unwrap_or_default();

                    let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();

                    // Sort rows so rows of the same partition are contiguous.
                    let indices = sort_to_indices(&part_ids, None, None)?;
                    let batch = batch.take(&indices)?;

                    // Re-read the now-sorted partition ids, then drop the
                    // column — it is not written to the partition files.
                    let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
                    let batch = batch.drop_column(PART_ID_COLUMN)?;

                    let mut partition_buffers = vec![Vec::new(); num_partitions];

                    // Walk maximal runs of equal partition ids; each run
                    // becomes one slice (start..end) of the sorted batch.
                    let mut start = 0;
                    while start < batch.num_rows() {
                        let part_id: u32 = part_ids.value(start);
                        let mut end = start + 1;
                        while end < batch.num_rows() && part_ids.value(end) == part_id {
                            end += 1;
                        }

                        let part_batches = &mut partition_buffers[part_id as usize];
                        part_batches.push(batch.slice(start, end - start));
                        start = end;
                    }

                    Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
                })
            })
            .buffered(get_num_compute_intensive_cpus());

        // Accumulator merging the per-batch slices until the next flush, e.g.:
        // part_id:           |       0        |       1        |       3        |
        // partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]|
        let mut partition_buffers = vec![Vec::new(); num_partitions];

        // counter: input batches consumed since the start (drives the flush cadence).
        let mut counter = 0;
        let mut total_loss = 0.0;
        while let Some(shuffled) = parallel_sort_stream.next().await {
            let (shuffled, loss) = shuffled?;
            total_loss += loss;

            for (part_id, batches) in shuffled.into_iter().enumerate() {
                let part_batches = &mut partition_buffers[part_id];
                part_batches.extend(batches);
            }

            counter += 1;

            // do flush: every `buffer_size` input batches, write each
            // partition's buffered slices to its file (all partitions in
            // parallel), record row counts, then clear the buffers.
            if counter % self.buffer_size == 0 {
                let mut futs = vec![];
                for (part_id, writer) in writers.iter_mut().enumerate() {
                    let batches = &partition_buffers[part_id];
                    partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
                    futs.push(writer.write_batches(batches.iter()));
                }
                try_join_all(futs).await?;

                partition_buffers.iter_mut().for_each(|b| b.clear());
            }
        }

        // final flush of whatever remains in the buffers (written
        // sequentially here, unlike the parallel periodic flush above).
        for (part_id, batches) in partition_buffers.into_iter().enumerate() {
            let writer = &mut writers[part_id];
            partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
            for batch in batches.iter() {
                writer.write_batch(batch).await?;
            }
        }

        // finish all writers so the files are fully committed before reading.
        for writer in writers.iter_mut() {
            writer.finish().await?;
        }

        Ok(Box::new(IvfShufflerReader::new(
            self.object_store.clone(),
            self.output_dir.clone(),
            partition_sizes,
            total_loss,
        )))
    }
}
219
/// Reader over the per-partition files produced by [`IvfShuffler`].
pub struct IvfShufflerReader {
    // Scheduler used to open and stream the partition files.
    scheduler: Arc<ScanScheduler>,
    // Directory containing the "ivf_{id}.lance" partition files.
    output_dir: Path,
    // Row count per partition, indexed by partition id.
    partition_sizes: Vec<usize>,
    // Total loss accumulated during shuffling.
    loss: f64,
}
226
227impl IvfShufflerReader {
228    pub fn new(
229        object_store: Arc<ObjectStore>,
230        output_dir: Path,
231        partition_sizes: Vec<usize>,
232        loss: f64,
233    ) -> Self {
234        let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
235        let scheduler = ScanScheduler::new(object_store, scheduler_config);
236        Self {
237            scheduler,
238            output_dir,
239            partition_sizes,
240            loss,
241        }
242    }
243}
244
#[async_trait::async_trait]
impl ShuffleReader for IvfShufflerReader {
    async fn read_partition(
        &self,
        partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
        // Out-of-range partition ids yield Ok(None) rather than an error.
        // NOTE(review): the trait doc promises Ok(None) for *empty*
        // partitions, but an in-range empty partition still returns a stream
        // here — callers must check `partition_size()` first, as the trait
        // documentation instructs.
        if partition_id >= self.partition_sizes.len() {
            return Ok(None);
        }

        let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));

        let reader = FileReader::try_open(
            self.scheduler
                // File size is not tracked for these transient shuffle files.
                .open_file(&partition_path, &CachedFileSize::unknown())
                .await?,
            None,
            Arc::<DecoderPlugins>::default(),
            // Shuffle output is read once; skip metadata caching.
            &LanceCache::no_cache(),
            FileReaderOptions::default(),
        )
        .await?;
        let schema: Schema = reader.schema().as_ref().into();
        Ok(Some(Box::new(RecordBatchStreamAdapter::new(
            Arc::new(schema),
            // Read the whole file; u32::MAX batch size effectively means "as
            // large as available". The `16` is presumably a read-ahead /
            // queue-depth parameter — confirm against FileReader::read_stream.
            reader.read_stream(
                lance_io::ReadBatchParams::RangeFull,
                u32::MAX,
                16,
                FilterExpression::no_filter(),
            )?,
        ))))
    }

    /// Row count for `partition_id`; unknown ids report 0.
    fn partition_size(&self, partition_id: usize) -> Result<usize> {
        Ok(self.partition_sizes.get(partition_id).copied().unwrap_or(0))
    }

    /// Total loss from the shuffle (always available for this reader).
    fn total_loss(&self) -> Option<f64> {
        Some(self.loss)
    }
}
287
/// A [`ShuffleReader`] with no data: every partition is absent and empty.
pub struct EmptyReader;

#[async_trait::async_trait]
impl ShuffleReader for EmptyReader {
    /// Always `Ok(None)` — there are no partitions to read.
    async fn read_partition(
        &self,
        _partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
        Ok(None)
    }

    /// Every partition reports zero rows.
    fn partition_size(&self, _partition_id: usize) -> Result<usize> {
        Ok(0)
    }

    /// No loss is tracked for an empty reader.
    fn total_loss(&self) -> Option<f64> {
        None
    }
}