1use std::sync::Arc;
5
6use arrow_array::{make_array, BooleanArray, RecordBatch, RecordBatchOptions, UInt64Array};
7use arrow_buffer::NullBuffer;
8use futures::{
9 future::BoxFuture,
10 stream::{BoxStream, FuturesOrdered},
11 FutureExt, Stream, StreamExt,
12};
13use lance_arrow::RecordBatchExt;
14use lance_core::{
15 utils::{address::RowAddress, deletion::DeletionVector},
16 Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD,
17};
18use lance_io::ReadBatchParams;
19
20use crate::rowids::RowIdSequence;
21
/// A boxed, `'static` future that resolves to one [`RecordBatch`].
pub type ReadBatchFut = BoxFuture<'static, Result<RecordBatch>>;
/// A pending batch read together with the number of rows it is expected to
/// produce.
///
/// The row count is available before the future is awaited, which lets
/// consumers (e.g. offset bookkeeping) run ahead of the actual I/O.
pub struct ReadBatchTask {
    /// The future that yields the batch.
    pub task: ReadBatchFut,
    /// Number of rows the batch produced by `task` should contain.
    pub num_rows: u32,
}
/// A boxed, `'static` stream of [`ReadBatchTask`]s.
pub type ReadBatchTaskStream = BoxStream<'static, ReadBatchTask>;
/// A boxed, `'static` stream of batch futures (row counts not attached).
pub type ReadBatchFutStream = BoxStream<'static, ReadBatchFut>;
31
32struct MergeStream {
33 streams: Vec<ReadBatchTaskStream>,
34 next_batch: FuturesOrdered<ReadBatchFut>,
35 next_num_rows: u32,
36 index: usize,
37}
38
39impl MergeStream {
40 fn emit(&mut self) -> ReadBatchTask {
41 let mut iter = std::mem::take(&mut self.next_batch);
42 let task = async move {
43 let mut batch = iter.next().await.unwrap()?;
44 while let Some(next) = iter.next().await {
45 let next = next?;
46 batch = batch.merge(&next)?;
47 }
48 Ok(batch)
49 }
50 .boxed();
51 let num_rows = self.next_num_rows;
52 self.next_num_rows = 0;
53 ReadBatchTask { task, num_rows }
54 }
55}
56
57impl Stream for MergeStream {
58 type Item = ReadBatchTask;
59
60 fn poll_next(
61 mut self: std::pin::Pin<&mut Self>,
62 cx: &mut std::task::Context<'_>,
63 ) -> std::task::Poll<Option<Self::Item>> {
64 loop {
65 let index = self.index;
66 match self.streams[index].poll_next_unpin(cx) {
67 std::task::Poll::Ready(Some(batch_task)) => {
68 if self.index == 0 {
69 self.next_num_rows = batch_task.num_rows;
70 } else {
71 debug_assert_eq!(self.next_num_rows, batch_task.num_rows);
72 }
73 self.next_batch.push_back(batch_task.task);
74 self.index += 1;
75 if self.index == self.streams.len() {
76 self.index = 0;
77 let next_batch = self.emit();
78 return std::task::Poll::Ready(Some(next_batch));
79 }
80 }
81 std::task::Poll::Ready(None) => {
82 return std::task::Poll::Ready(None);
83 }
84 std::task::Poll::Pending => {
85 return std::task::Poll::Pending;
86 }
87 }
88 }
89 }
90}
91
92pub fn merge_streams(streams: Vec<ReadBatchTaskStream>) -> ReadBatchTaskStream {
105 MergeStream {
106 streams,
107 next_batch: FuturesOrdered::new(),
108 next_num_rows: 0,
109 index: 0,
110 }
111 .boxed()
112}
113
114fn apply_deletions_as_nulls(batch: RecordBatch, mask: &BooleanArray) -> Result<RecordBatch> {
117 let mask_buffer = NullBuffer::new(mask.values().clone());
121
122 match mask_buffer.null_count() {
123 n if n == mask_buffer.len() => return Ok(RecordBatch::new_empty(batch.schema())),
125 0 => return Ok(batch),
127 _ => {}
128 }
129
130 let new_columns = batch
132 .schema()
133 .fields()
134 .iter()
135 .zip(batch.columns())
136 .map(|(field, col)| {
137 if field.name() == ROW_ID || field.name() == ROW_ADDR {
138 let col_data = col.to_data();
139 let null_buffer = NullBuffer::union(col_data.nulls(), Some(&mask_buffer));
142
143 Ok(col_data
144 .into_builder()
145 .null_bit_buffer(null_buffer.map(|b| b.buffer().clone()))
146 .build()
147 .map(make_array)?)
148 } else {
149 Ok(col.clone())
150 }
151 })
152 .collect::<Result<Vec<_>>>()?;
153
154 Ok(RecordBatch::try_new_with_options(
155 batch.schema(),
156 new_columns,
157 &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
158 )?)
159}
160
/// Settings that control how [`apply_row_id_and_deletes`] augments and filters
/// each batch read from a fragment.
pub struct RowIdAndDeletesConfig {
    /// The selection of fragment-local row offsets the batches were read with;
    /// it is sliced per batch to recover each row's offset.
    pub params: ReadBatchParams,
    /// Whether to append a `ROW_ID` column to each batch.
    pub with_row_id: bool,
    /// Whether to append a `ROW_ADDR` column to each batch.
    pub with_row_addr: bool,
    /// Optional deletion vector to apply; `DeletionVector::NoDeletions` is
    /// treated the same as `None`.
    pub deletion_vector: Option<Arc<DeletionVector>>,
    /// Optional source of row ids; when `None`, row ids are the row addresses.
    pub row_id_sequence: Option<Arc<RowIdSequence>>,
    /// If true, deleted rows are kept with null row id / row address values
    /// instead of being filtered out (a batch whose rows are all deleted is
    /// still replaced by an empty batch).
    pub make_deletions_null: bool,
    /// Total number of rows in the read. Not read by the code in this section;
    /// presumably needed to resolve open-ended selections — TODO confirm.
    pub total_num_rows: u32,
}
180
181pub fn apply_row_id_and_deletes(
182 batch: RecordBatch,
183 batch_offset: u32,
184 fragment_id: u32,
185 config: &RowIdAndDeletesConfig,
186) -> Result<RecordBatch> {
187 let mut deletion_vector = config.deletion_vector.as_ref();
188 if let Some(deletion_vector_inner) = deletion_vector {
190 if matches!(deletion_vector_inner.as_ref(), DeletionVector::NoDeletions) {
191 deletion_vector = None;
192 }
193 }
194 let has_deletions = deletion_vector.is_some();
195 debug_assert!(
196 batch.num_columns() > 0 || config.with_row_id || config.with_row_addr || has_deletions
197 );
198
199 let should_fetch_row_addr = config.with_row_addr
201 || (config.with_row_id && config.row_id_sequence.is_none())
202 || has_deletions;
203
204 let num_rows = batch.num_rows() as u32;
205
206 let row_addrs = if should_fetch_row_addr {
207 let ids_in_batch = config
208 .params
209 .slice(batch_offset as usize, num_rows as usize)
210 .unwrap()
211 .to_offsets()
212 .unwrap();
213 let row_addrs: UInt64Array = ids_in_batch
214 .values()
215 .iter()
216 .map(|row_id| u64::from(RowAddress::new_from_parts(fragment_id, *row_id)))
217 .collect();
218
219 Some(Arc::new(row_addrs))
220 } else {
221 None
222 };
223
224 let row_ids = if config.with_row_id {
225 if let Some(row_id_sequence) = &config.row_id_sequence {
226 let row_ids = row_id_sequence
227 .slice(batch_offset as usize, num_rows as usize)
228 .iter()
229 .collect::<UInt64Array>();
230 Some(Arc::new(row_ids))
231 } else {
232 row_addrs.clone()
235 }
236 } else {
237 None
238 };
239
240 let span = tracing::span!(tracing::Level::DEBUG, "apply_deletions");
247 let _enter = span.enter();
248 let deletion_mask = deletion_vector.and_then(|v| {
249 let row_addrs: &[u64] = row_addrs.as_ref().unwrap().values();
250 v.build_predicate(row_addrs.iter())
251 });
252
253 let batch = if config.with_row_id {
254 let row_id_arr = row_ids.unwrap();
255 batch.try_with_column(ROW_ID_FIELD.clone(), row_id_arr)?
256 } else {
257 batch
258 };
259
260 let batch = if config.with_row_addr {
261 let row_addr_arr = row_addrs.unwrap();
262 batch.try_with_column(ROW_ADDR_FIELD.clone(), row_addr_arr)?
263 } else {
264 batch
265 };
266
267 match (deletion_mask, config.make_deletions_null) {
268 (None, _) => Ok(batch),
269 (Some(mask), false) => Ok(arrow::compute::filter_record_batch(&batch, &mask)?),
270 (Some(mask), true) => Ok(apply_deletions_as_nulls(batch, &mask)?),
271 }
272}
273
274pub fn wrap_with_row_id_and_delete(
280 stream: ReadBatchTaskStream,
281 fragment_id: u32,
282 config: RowIdAndDeletesConfig,
283) -> ReadBatchFutStream {
284 let config = Arc::new(config);
285 let mut offset = 0;
286 stream
287 .map(move |batch_task| {
288 let config = config.clone();
289 let this_offset = offset;
290 let num_rows = batch_task.num_rows;
291 offset += num_rows;
292 let task = batch_task.task;
293 async move {
294 let batch = task.await?;
295 apply_row_id_and_deletes(batch, this_offset, fragment_id, config.as_ref())
296 }
297 .boxed()
298 })
299 .boxed()
300}
301
// Tests for merging read-task streams and for attaching row ids / applying
// deletions to fragment batches.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::AsArray, datatypes::UInt64Type};
    use arrow_array::{types::Int32Type, RecordBatch, UInt32Array};
    use arrow_schema::ArrowError;
    use futures::{stream::BoxStream, FutureExt, StreamExt, TryStreamExt};
    use lance_core::{
        utils::{address::RowAddress, deletion::DeletionVector},
        ROW_ID,
    };
    use lance_datagen::{BatchCount, RowCount};
    use lance_io::{stream::arrow_stream_to_lance_stream, ReadBatchParams};
    use roaring::RoaringBitmap;

    use crate::utils::stream::ReadBatchTask;

    use super::RowIdAndDeletesConfig;

    /// Wraps a datagen record-batch stream into a task stream whose tasks are
    /// already-completed futures.
    ///
    /// Panics (via `unwrap`) if the datagen stream yields an error, because the
    /// row count is read before the task is built.
    fn batch_task_stream(
        datagen_stream: BoxStream<'static, std::result::Result<RecordBatch, ArrowError>>,
    ) -> super::ReadBatchTaskStream {
        arrow_stream_to_lance_stream(datagen_stream)
            .map(|batch| ReadBatchTask {
                num_rows: batch.as_ref().unwrap().num_rows() as u32,
                task: std::future::ready(batch).boxed(),
            })
            .boxed()
    }

    // Merging an "x" stream with a "y" stream must produce the same batches as
    // generating both columns together.
    #[tokio::test]
    async fn test_basic_zip() {
        let left = batch_task_stream(
            lance_datagen::gen()
                .col("x", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );
        let right = batch_task_stream(
            lance_datagen::gen()
                .col("y", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );

        let merged = super::merge_streams(vec![left, right])
            .map(|batch_task| batch_task.task)
            .buffered(1)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let expected = lance_datagen::gen()
            .col("x", lance_datagen::array::step::<Int32Type>())
            .col("y", lance_datagen::array::step::<Int32Type>())
            .into_reader_rows(RowCount::from(100), BatchCount::from(10))
            .collect::<Result<Vec<_>, ArrowError>>()
            .unwrap();
        assert_eq!(merged, expected);
    }

    /// Reads 10 batches of 10 rows with `params` and checks that each batch's
    /// `ROW_ID` column equals the row addresses built from the corresponding
    /// `expected` offsets, with and without a data column and for two
    /// fragment ids.
    async fn check_row_id(params: ReadBatchParams, expected: impl IntoIterator<Item = u32>) {
        let expected = Vec::from_iter(expected);

        for has_columns in [false, true] {
            for fragment_id in [0, 10] {
                let mut datagen = lance_datagen::gen();
                if has_columns {
                    datagen = datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                }
                let data = batch_task_stream(
                    datagen
                        .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                        .0,
                );

                let config = RowIdAndDeletesConfig {
                    params: params.clone(),
                    with_row_id: true,
                    with_row_addr: false,
                    deletion_vector: None,
                    row_id_sequence: None,
                    make_deletions_null: false,
                    total_num_rows: 100,
                };
                let stream = super::wrap_with_row_id_and_delete(data, fragment_id, config);
                let batches = stream.buffered(1).try_collect::<Vec<_>>().await.unwrap();

                let mut offset = 0;
                let expected = expected.clone();
                for batch in batches {
                    let actual_row_ids =
                        batch[ROW_ID].as_primitive::<UInt64Type>().values().to_vec();
                    // The window of 10 matches the 10-row batches generated above.
                    let expected_row_ids = expected[offset..offset + 10]
                        .iter()
                        .map(|row_offset| {
                            RowAddress::new_from_parts(fragment_id, *row_offset).into()
                        })
                        .collect::<Vec<u64>>();
                    assert_eq!(actual_row_ids, expected_row_ids);
                    offset += batch.num_rows();
                }
            }
        }
    }

    // Row ids must follow the read selection for every `ReadBatchParams` kind.
    #[tokio::test]
    async fn test_row_id() {
        let some_indices = (0..100).rev().collect::<Vec<u32>>();
        let some_indices_arr = UInt32Array::from(some_indices.clone());
        check_row_id(ReadBatchParams::RangeFull, 0..100).await;
        check_row_id(ReadBatchParams::Indices(some_indices_arr), some_indices).await;
        check_row_id(ReadBatchParams::Range(1000..1100), 1000..1100).await;
        check_row_id(
            ReadBatchParams::RangeFrom(std::ops::RangeFrom { start: 1000 }),
            1000..1100,
        )
        .await;
        // Only 100 rows are read, so `..1000` still yields offsets 0..100.
        check_row_id(
            ReadBatchParams::RangeTo(std::ops::RangeTo { end: 1000 }),
            0..100,
        )
        .await;
    }

    // Deleting rows 0..35 (via bitmap or set) must remove or null exactly 35 of
    // the 100 rows, across all combinations of columns / row ids / null mode.
    #[tokio::test]
    async fn test_deletes() {
        let no_deletes: Option<Arc<DeletionVector>> = None;
        let no_deletes_2 = Some(Arc::new(DeletionVector::NoDeletions));
        let delete_some_bitmap = Some(Arc::new(DeletionVector::Bitmap(RoaringBitmap::from_iter(
            0..35,
        ))));
        let delete_some_set = Some(Arc::new(DeletionVector::Set((0..35).collect())));

        for deletion_vector in [
            no_deletes,
            no_deletes_2,
            delete_some_bitmap,
            delete_some_set,
        ] {
            for has_columns in [false, true] {
                for with_row_id in [false, true] {
                    for make_deletions_null in [false, true] {
                        for frag_id in [0, 1] {
                            let has_deletions = if let Some(dv) = &deletion_vector {
                                !matches!(dv.as_ref(), DeletionVector::NoDeletions)
                            } else {
                                false
                            };
                            // Skip configurations that would produce batches
                            // with no columns and nothing to do.
                            if !has_columns && !has_deletions && !with_row_id {
                                continue;
                            }
                            // Nulling deletions only affects the row id/addr
                            // columns, so it needs the row id column.
                            if make_deletions_null && !with_row_id {
                                continue;
                            }

                            let mut datagen = lance_datagen::gen();
                            if has_columns {
                                datagen =
                                    datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                            }
                            let data = batch_task_stream(
                                datagen
                                    .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                                    .0,
                            );

                            let config = RowIdAndDeletesConfig {
                                params: ReadBatchParams::RangeFull,
                                with_row_id,
                                with_row_addr: false,
                                deletion_vector: deletion_vector.clone(),
                                row_id_sequence: None,
                                make_deletions_null,
                                total_num_rows: 100,
                            };
                            let stream = super::wrap_with_row_id_and_delete(data, frag_id, config);
                            // Drop batches that ended up empty (fully deleted).
                            let batches = stream
                                .buffered(1)
                                .filter_map(|batch| {
                                    std::future::ready(
                                        batch
                                            .map(|batch| {
                                                if batch.num_rows() == 0 {
                                                    None
                                                } else {
                                                    Some(batch)
                                                }
                                            })
                                            .transpose(),
                                    )
                                })
                                .try_collect::<Vec<_>>()
                                .await
                                .unwrap();

                            // Deleted rows are either missing or null row ids.
                            let total_num_rows =
                                batches.iter().map(|b| b.num_rows()).sum::<usize>();
                            let total_num_nulls = if make_deletions_null {
                                batches
                                    .iter()
                                    .map(|b| b[ROW_ID].null_count())
                                    .sum::<usize>()
                            } else {
                                0
                            };
                            let total_actually_deleted = total_num_nulls + (100 - total_num_rows);

                            let expected_deletions = match &deletion_vector {
                                None => 0,
                                Some(deletion_vector) => match deletion_vector.as_ref() {
                                    DeletionVector::NoDeletions => 0,
                                    DeletionVector::Bitmap(b) => b.len() as usize,
                                    DeletionVector::Set(s) => s.len(),
                                },
                            };
                            assert_eq!(total_actually_deleted, expected_deletions);
                            if expected_deletions > 0 && with_row_id {
                                if make_deletions_null {
                                    // Batches 0..3 are fully deleted and dropped;
                                    // the first remaining batch starts at row 30
                                    // (whose slot is null but keeps its value).
                                    assert_eq!(
                                        batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 30))
                                    );
                                } else {
                                    // Rows 0..35 are filtered out entirely.
                                    assert_eq!(
                                        batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 35))
                                    );
                                }
                            }
                            if !with_row_id {
                                assert!(batches[0].column_by_name(ROW_ID).is_none());
                            }
                        }
                    }
                }
            }
        }
    }
}