use std::sync::Arc;

use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array};
use arrow_schema::Schema;
// Fixed: `try_join_all` lives in `futures::future`, not a crate named `future`.
use futures::future::try_join_all;
use futures::prelude::*;
use lance_arrow::{RecordBatchExt, SchemaExt};
use lance_core::{
    cache::LanceCache,
    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
    Error, Result,
};
use lance_encoding::decoder::{DecoderPlugins, FilterExpression};
use lance_file::reader::{FileReader, FileReaderOptions};
use lance_file::writer::FileWriter;
use lance_io::{
    object_store::ObjectStore,
    scheduler::{ScanScheduler, SchedulerConfig},
    stream::{RecordBatchStream, RecordBatchStreamAdapter},
    utils::CachedFileSize,
};
use object_store::path::Path;

use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
32
/// Read side of a shuffle: exposes the rows that were routed to each partition.
#[async_trait::async_trait]
pub trait ShuffleReader: Send + Sync {
    /// Streams the rows of `partition_id`.
    ///
    /// Returns `Ok(None)` when there is nothing to read for this partition
    /// (e.g. the id is out of range — see the implementors in this file).
    async fn read_partition(
        &self,
        partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>>;

    /// Number of rows recorded for `partition_id`.
    fn partition_size(&self, partition_id: usize) -> Result<usize>;

    /// Total loss accumulated during the shuffle, if one was recorded.
    fn total_loss(&self) -> Option<f64>;
}
53
/// Write side of a shuffle: consumes a batch stream and partitions it,
/// handing back a [`ShuffleReader`] over the result.
#[async_trait::async_trait]
pub trait Shuffler: Send + Sync {
    /// Shuffles `data` into partitions and returns a reader over them.
    async fn shuffle(
        &self,
        data: Box<dyn RecordBatchStream + Unpin + 'static>,
    ) -> Result<Box<dyn ShuffleReader>>;
}
65
/// Shuffles a stream of record batches into one Lance file per IVF partition.
pub struct IvfShuffler {
    object_store: Arc<ObjectStore>,
    // Directory receiving one `ivf_{id}.lance` file per partition.
    output_dir: Path,
    num_partitions: usize,

    // Number of input batches accumulated before flushing to the writers.
    buffer_size: usize,
    // NOTE(review): set via the builder but never read in this file —
    // presumably consumed elsewhere in the crate; confirm before removing.
    precomputed_shuffle_buffers: Option<Vec<String>>,
}
75
76impl IvfShuffler {
77 pub fn new(output_dir: Path, num_partitions: usize) -> Self {
78 Self {
79 object_store: Arc::new(ObjectStore::local()),
80 output_dir,
81 num_partitions,
82 buffer_size: 4096,
83 precomputed_shuffle_buffers: None,
84 }
85 }
86
87 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
88 self.buffer_size = buffer_size;
89 self
90 }
91
92 pub fn with_precomputed_shuffle_buffers(
93 mut self,
94 precomputed_shuffle_buffers: Option<Vec<String>>,
95 ) -> Self {
96 self.precomputed_shuffle_buffers = precomputed_shuffle_buffers;
97 self
98 }
99}
100
101#[async_trait::async_trait]
102impl Shuffler for IvfShuffler {
103 async fn shuffle(
104 &self,
105 data: Box<dyn RecordBatchStream + Unpin + 'static>,
106 ) -> Result<Box<dyn ShuffleReader>> {
107 let num_partitions = self.num_partitions;
108 let mut partition_sizes = vec![0; num_partitions];
109 let schema = data.schema().without_column(PART_ID_COLUMN);
110 let mut writers = stream::iter(0..num_partitions)
111 .map(|partition_id| {
112 let part_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));
113 let object_store = self.object_store.clone();
114 let schema = schema.clone();
115 async move {
116 let writer = object_store.create(&part_path).await?;
117 FileWriter::try_new(
118 writer,
119 lance_core::datatypes::Schema::try_from(&schema)?,
120 Default::default(),
121 )
122 }
123 })
124 .buffered(self.object_store.io_parallelism())
125 .try_collect::<Vec<_>>()
126 .await?;
127 let mut parallel_sort_stream = data
128 .map(|batch| {
129 spawn_cpu(move || {
130 let batch = batch?;
131
132 let loss = batch
133 .metadata()
134 .get(LOSS_METADATA_KEY)
135 .map(|s| s.parse::<f64>().unwrap_or_default())
136 .unwrap_or_default();
137
138 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
139
140 let indices = sort_to_indices(&part_ids, None, None)?;
141 let batch = batch.take(&indices)?;
142
143 let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
144 let batch = batch.drop_column(PART_ID_COLUMN)?;
145
146 let mut partition_buffers = vec![Vec::new(); num_partitions];
147
148 let mut start = 0;
149 while start < batch.num_rows() {
150 let part_id: u32 = part_ids.value(start);
151 let mut end = start + 1;
152 while end < batch.num_rows() && part_ids.value(end) == part_id {
153 end += 1;
154 }
155
156 let part_batches = &mut partition_buffers[part_id as usize];
157 part_batches.push(batch.slice(start, end - start));
158 start = end;
159 }
160
161 Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
162 })
163 })
164 .buffered(get_num_compute_intensive_cpus());
165
166 let mut partition_buffers = vec![Vec::new(); num_partitions];
169
170 let mut counter = 0;
171 let mut total_loss = 0.0;
172 while let Some(shuffled) = parallel_sort_stream.next().await {
173 let (shuffled, loss) = shuffled?;
174 total_loss += loss;
175
176 for (part_id, batches) in shuffled.into_iter().enumerate() {
177 let part_batches = &mut partition_buffers[part_id];
178 part_batches.extend(batches);
179 }
180
181 counter += 1;
182
183 if counter % self.buffer_size == 0 {
185 let mut futs = vec![];
186 for (part_id, writer) in writers.iter_mut().enumerate() {
187 let batches = &partition_buffers[part_id];
188 partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
189 futs.push(writer.write_batches(batches.iter()));
190 }
191 try_join_all(futs).await?;
192
193 partition_buffers.iter_mut().for_each(|b| b.clear());
194 }
195 }
196
197 for (part_id, batches) in partition_buffers.into_iter().enumerate() {
199 let writer = &mut writers[part_id];
200 partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
201 for batch in batches.iter() {
202 writer.write_batch(batch).await?;
203 }
204 }
205
206 for writer in writers.iter_mut() {
208 writer.finish().await?;
209 }
210
211 Ok(Box::new(IvfShufflerReader::new(
212 self.object_store.clone(),
213 self.output_dir.clone(),
214 partition_sizes,
215 total_loss,
216 )))
217 }
218}
219
/// Reads back the per-partition Lance files produced by [`IvfShuffler`].
pub struct IvfShufflerReader {
    // Shared scan scheduler used to open every partition file.
    scheduler: Arc<ScanScheduler>,
    // Directory containing the `ivf_{id}.lance` partition files.
    output_dir: Path,
    // Row count per partition, indexed by partition id.
    partition_sizes: Vec<usize>,
    // Total loss accumulated during the shuffle.
    loss: f64,
}
226
227impl IvfShufflerReader {
228 pub fn new(
229 object_store: Arc<ObjectStore>,
230 output_dir: Path,
231 partition_sizes: Vec<usize>,
232 loss: f64,
233 ) -> Self {
234 let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
235 let scheduler = ScanScheduler::new(object_store, scheduler_config);
236 Self {
237 scheduler,
238 output_dir,
239 partition_sizes,
240 loss,
241 }
242 }
243}
244
245#[async_trait::async_trait]
246impl ShuffleReader for IvfShufflerReader {
247 async fn read_partition(
248 &self,
249 partition_id: usize,
250 ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
251 if partition_id >= self.partition_sizes.len() {
252 return Ok(None);
253 }
254
255 let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id));
256
257 let reader = FileReader::try_open(
258 self.scheduler
259 .open_file(&partition_path, &CachedFileSize::unknown())
260 .await?,
261 None,
262 Arc::<DecoderPlugins>::default(),
263 &LanceCache::no_cache(),
264 FileReaderOptions::default(),
265 )
266 .await?;
267 let schema: Schema = reader.schema().as_ref().into();
268 Ok(Some(Box::new(RecordBatchStreamAdapter::new(
269 Arc::new(schema),
270 reader.read_stream(
271 lance_io::ReadBatchParams::RangeFull,
272 u32::MAX,
273 16,
274 FilterExpression::no_filter(),
275 )?,
276 ))))
277 }
278
279 fn partition_size(&self, partition_id: usize) -> Result<usize> {
280 Ok(self.partition_sizes.get(partition_id).copied().unwrap_or(0))
281 }
282
283 fn total_loss(&self) -> Option<f64> {
284 Some(self.loss)
285 }
286}
287
/// A [`ShuffleReader`] over no data: every partition is absent and empty.
pub struct EmptyReader;

#[async_trait::async_trait]
impl ShuffleReader for EmptyReader {
    /// Always `Ok(None)` — there are no partitions to read.
    async fn read_partition(
        &self,
        _partition_id: usize,
    ) -> Result<Option<Box<dyn RecordBatchStream + Unpin + 'static>>> {
        Ok(None)
    }

    /// Always zero rows, regardless of the id.
    fn partition_size(&self, _partition_id: usize) -> Result<usize> {
        Ok(0)
    }

    /// No loss is recorded for an empty shuffle.
    fn total_loss(&self) -> Option<f64> {
        None
    }
}