1use std::sync::Arc;
5
6use arrow_array::{make_array, BooleanArray, RecordBatch, RecordBatchOptions, UInt64Array};
7use arrow_buffer::NullBuffer;
8use futures::{
9 future::BoxFuture,
10 stream::{BoxStream, FuturesOrdered},
11 FutureExt, Stream, StreamExt,
12};
13use lance_arrow::RecordBatchExt;
14use lance_core::{
15 utils::{address::RowAddress, deletion::DeletionVector},
16 Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_CREATED_AT_VERSION_FIELD, ROW_ID, ROW_ID_FIELD,
17 ROW_LAST_UPDATED_AT_VERSION_FIELD,
18};
19use lance_io::ReadBatchParams;
20use tracing::{instrument, Instrument};
21
22use crate::rowids::RowIdSequence;
23
/// A future that resolves to a single [`RecordBatch`] (or an error).
pub type ReadBatchFut = BoxFuture<'static, Result<RecordBatch>>;
/// A unit of work in a batch-read pipeline: the future producing the batch
/// plus the number of rows that batch will contain.  The row count is known
/// before the data is awaited, so downstream operators can plan ahead.
pub struct ReadBatchTask {
    /// Future resolving to the batch's data.
    pub task: ReadBatchFut,
    /// Number of rows the resolved batch will contain.
    pub num_rows: u32,
}
/// A stream of batch-read tasks (row counts available before the data).
pub type ReadBatchTaskStream = BoxStream<'static, ReadBatchTask>;
/// A stream of bare batch futures (no up-front row counts).
pub type ReadBatchFutStream = BoxStream<'static, ReadBatchFut>;
33
/// Zips several [`ReadBatchTaskStream`]s together by merging the columns of
/// their batches (see [`merge_streams`]).  Streams are expected to yield
/// batches with identical row counts (checked via debug assertion).
struct MergeStream {
    /// The input streams, polled round-robin starting at `index`.
    streams: Vec<ReadBatchTaskStream>,
    /// One pending batch future collected from each stream for the current
    /// output batch, kept in stream order.
    next_batch: FuturesOrdered<ReadBatchFut>,
    /// Row count of the batch currently being assembled (taken from the
    /// first stream's task).
    next_num_rows: u32,
    /// Index of the stream to poll next.
    index: usize,
}
40
41impl MergeStream {
42 fn emit(&mut self) -> ReadBatchTask {
43 let mut iter = std::mem::take(&mut self.next_batch);
44 let task = async move {
45 let mut batch = iter.next().await.unwrap()?;
46 while let Some(next) = iter.next().await {
47 let next = next?;
48 batch = batch.merge(&next)?;
49 }
50 Ok(batch)
51 }
52 .boxed();
53 let num_rows = self.next_num_rows;
54 self.next_num_rows = 0;
55 ReadBatchTask { task, num_rows }
56 }
57}
58
impl Stream for MergeStream {
    type Item = ReadBatchTask;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Poll the input streams round-robin, collecting one task per
        // stream; once every stream has contributed, emit the merged task.
        loop {
            let index = self.index;
            match self.streams[index].poll_next_unpin(cx) {
                std::task::Poll::Ready(Some(batch_task)) => {
                    if self.index == 0 {
                        // The first stream defines the merged batch's row
                        // count...
                        self.next_num_rows = batch_task.num_rows;
                    } else {
                        // ...and every other stream must agree with it.
                        debug_assert_eq!(self.next_num_rows, batch_task.num_rows);
                    }
                    self.next_batch.push_back(batch_task.task);
                    self.index += 1;
                    if self.index == self.streams.len() {
                        // Completed a full round: package the collected
                        // futures into one merged task and start over.
                        self.index = 0;
                        let next_batch = self.emit();
                        return std::task::Poll::Ready(Some(next_batch));
                    }
                }
                std::task::Poll::Ready(None) => {
                    // Any stream ending ends the merged stream.
                    // NOTE(review): tasks already collected mid-round are
                    // dropped here — assumes all input streams yield the
                    // same number of batches; confirm with callers.
                    return std::task::Poll::Ready(None);
                }
                std::task::Poll::Pending => {
                    return std::task::Poll::Pending;
                }
            }
        }
    }
}
93
94pub fn merge_streams(streams: Vec<ReadBatchTaskStream>) -> ReadBatchTaskStream {
107 MergeStream {
108 streams,
109 next_batch: FuturesOrdered::new(),
110 next_num_rows: 0,
111 index: 0,
112 }
113 .boxed()
114}
115
116fn apply_deletions_as_nulls(batch: RecordBatch, mask: &BooleanArray) -> Result<RecordBatch> {
123 let mask_buffer = NullBuffer::new(mask.values().clone());
127
128 if mask_buffer.null_count() == 0 {
129 return Ok(batch);
131 }
132
133 let new_columns = batch
135 .schema()
136 .fields()
137 .iter()
138 .zip(batch.columns())
139 .map(|(field, col)| {
140 if field.name() == ROW_ID || field.name() == ROW_ADDR {
141 let col_data = col.to_data();
142 let null_buffer = NullBuffer::union(col_data.nulls(), Some(&mask_buffer));
145
146 Ok(col_data
147 .into_builder()
148 .null_bit_buffer(null_buffer.map(|b| b.buffer().clone()))
149 .build()
150 .map(make_array)?)
151 } else {
152 Ok(col.clone())
153 }
154 })
155 .collect::<Result<Vec<_>>>()?;
156
157 Ok(RecordBatch::try_new_with_options(
158 batch.schema(),
159 new_columns,
160 &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
161 )?)
162}
163
/// Configuration controlling how row ids, row addresses, version columns and
/// deletions are applied to batches read from a fragment.
#[derive(Debug)]
pub struct RowIdAndDeletesConfig {
    /// Which rows of the fragment are being read; used to map batch offsets
    /// back to fragment-relative row offsets.
    pub params: ReadBatchParams,
    /// If true, attach the row-id column (`ROW_ID_FIELD`).
    pub with_row_id: bool,
    /// If true, attach the row-address column (`ROW_ADDR_FIELD`).
    pub with_row_addr: bool,
    /// If true, attach the "last updated at version" column.
    pub with_row_last_updated_at_version: bool,
    /// If true, attach the "created at version" column.
    pub with_row_created_at_version: bool,
    /// The fragment's deletion vector, if any deletions exist.
    pub deletion_vector: Option<Arc<DeletionVector>>,
    /// Stable row-id sequence for the fragment.  When absent, row addresses
    /// are used as row ids.
    pub row_id_sequence: Option<Arc<RowIdSequence>>,
    /// Per-row "last updated at" dataset versions; rows default to version 1
    /// when absent.
    pub last_updated_at_sequence: Option<Arc<crate::rowids::version::RowDatasetVersionSequence>>,
    /// Per-row "created at" dataset versions; rows default to version 1 when
    /// absent.
    pub created_at_sequence: Option<Arc<crate::rowids::version::RowDatasetVersionSequence>>,
    /// If true, deleted rows are kept with their row-id/addr values nulled;
    /// if false, deleted rows are filtered out of the batch.
    pub make_deletions_null: bool,
    /// Total number of rows in the overall read.
    /// NOTE(review): not consulted within this module's visible code —
    /// presumably used by callers; confirm before relying on it.
    pub total_num_rows: u32,
}
192
193impl RowIdAndDeletesConfig {
194 fn has_system_cols(&self) -> bool {
195 self.with_row_id
196 || self.with_row_addr
197 || self.with_row_last_updated_at_version
198 || self.with_row_created_at_version
199 }
200}
201
202#[instrument(level = "debug", skip_all)]
203pub fn apply_row_id_and_deletes(
204 batch: RecordBatch,
205 batch_offset: u32,
206 fragment_id: u32,
207 config: &RowIdAndDeletesConfig,
208) -> Result<RecordBatch> {
209 let mut deletion_vector = config.deletion_vector.as_ref();
210 if let Some(deletion_vector_inner) = deletion_vector {
212 if matches!(deletion_vector_inner.as_ref(), DeletionVector::NoDeletions) {
213 deletion_vector = None;
214 }
215 }
216 let has_deletions = deletion_vector.is_some();
217 debug_assert!(batch.num_columns() > 0 || config.has_system_cols() || has_deletions);
218
219 let should_fetch_row_addr = config.with_row_addr
221 || (config.with_row_id && config.row_id_sequence.is_none())
222 || has_deletions;
223
224 let num_rows = batch.num_rows() as u32;
225
226 let row_addrs =
227 if should_fetch_row_addr {
228 let _rowaddrs = tracing::span!(tracing::Level::DEBUG, "fetch_row_addrs").entered();
229 let mut row_addrs = Vec::with_capacity(num_rows as usize);
230 for offset_range in config
231 .params
232 .slice(batch_offset as usize, num_rows as usize)
233 .unwrap()
234 .iter_offset_ranges()?
235 {
236 row_addrs.extend(offset_range.map(|row_offset| {
237 u64::from(RowAddress::new_from_parts(fragment_id, row_offset))
238 }));
239 }
240
241 Some(Arc::new(UInt64Array::from(row_addrs)))
242 } else {
243 None
244 };
245
246 let row_ids = if config.with_row_id {
247 let _rowids = tracing::span!(tracing::Level::DEBUG, "fetch_row_ids").entered();
248 if let Some(row_id_sequence) = &config.row_id_sequence {
249 let selection = config
250 .params
251 .slice(batch_offset as usize, num_rows as usize)
252 .unwrap()
253 .to_ranges()
254 .unwrap();
255 let row_ids = row_id_sequence
256 .select(
257 selection
258 .iter()
259 .flat_map(|r| r.start as usize..r.end as usize),
260 )
261 .collect::<UInt64Array>();
262 Some(Arc::new(row_ids))
263 } else {
264 row_addrs.clone()
267 }
268 } else {
269 None
270 };
271
272 let span = tracing::span!(tracing::Level::DEBUG, "apply_deletions");
273 let _enter = span.enter();
274 let deletion_mask = deletion_vector.and_then(|v| {
275 let row_addrs: &[u64] = row_addrs.as_ref().unwrap().values();
276 v.build_predicate(row_addrs.iter())
277 });
278
279 let batch = if config.with_row_id {
280 let row_id_arr = row_ids.unwrap();
281 batch.try_with_column(ROW_ID_FIELD.clone(), row_id_arr)?
282 } else {
283 batch
284 };
285
286 let batch = if config.with_row_addr {
287 let row_addr_arr = row_addrs.unwrap();
288 batch.try_with_column(ROW_ADDR_FIELD.clone(), row_addr_arr)?
289 } else {
290 batch
291 };
292
293 let batch = if config.with_row_last_updated_at_version || config.with_row_created_at_version {
295 let mut batch = batch;
296
297 if config.with_row_last_updated_at_version {
298 let version_arr = if let Some(sequence) = &config.last_updated_at_sequence {
299 let selection = config
301 .params
302 .slice(batch_offset as usize, num_rows as usize)
303 .unwrap()
304 .to_ranges()
305 .unwrap();
306 let versions: Vec<u64> = selection
308 .iter()
309 .flat_map(|r| {
310 sequence
311 .versions()
312 .skip(r.start as usize)
313 .take((r.end - r.start) as usize)
314 })
315 .collect();
316 Arc::new(UInt64Array::from(versions))
317 } else {
318 Arc::new(UInt64Array::from(vec![1u64; num_rows as usize]))
320 };
321 batch =
322 batch.try_with_column(ROW_LAST_UPDATED_AT_VERSION_FIELD.clone(), version_arr)?;
323 }
324
325 if config.with_row_created_at_version {
326 let version_arr = if let Some(sequence) = &config.created_at_sequence {
327 let selection = config
329 .params
330 .slice(batch_offset as usize, num_rows as usize)
331 .unwrap()
332 .to_ranges()
333 .unwrap();
334 let versions: Vec<u64> = selection
336 .iter()
337 .flat_map(|r| {
338 sequence
339 .versions()
340 .skip(r.start as usize)
341 .take((r.end - r.start) as usize)
342 })
343 .collect();
344 Arc::new(UInt64Array::from(versions))
345 } else {
346 Arc::new(UInt64Array::from(vec![1u64; num_rows as usize]))
348 };
349 batch = batch.try_with_column(ROW_CREATED_AT_VERSION_FIELD.clone(), version_arr)?;
350 }
351
352 batch
353 } else {
354 batch
355 };
356
357 match (deletion_mask, config.make_deletions_null) {
358 (None, _) => Ok(batch),
359 (Some(mask), false) => Ok(arrow::compute::filter_record_batch(&batch, &mask)?),
360 (Some(mask), true) => Ok(apply_deletions_as_nulls(batch, &mask)?),
361 }
362}
363
364pub fn wrap_with_row_id_and_delete(
370 stream: ReadBatchTaskStream,
371 fragment_id: u32,
372 config: RowIdAndDeletesConfig,
373) -> ReadBatchFutStream {
374 let config = Arc::new(config);
375 let mut offset = 0;
376 stream
377 .map(move |batch_task| {
378 let config = config.clone();
379 let this_offset = offset;
380 let num_rows = batch_task.num_rows;
381 offset += num_rows;
382 let task = batch_task.task;
383 tokio::spawn(
384 async move {
385 let batch = task.await?;
386 apply_row_id_and_deletes(batch, this_offset, fragment_id, config.as_ref())
387 }
388 .in_current_span(),
389 )
390 .map(|join_wrapper| join_wrapper.unwrap())
391 .boxed()
392 })
393 .boxed()
394}
395
396#[cfg(test)]
397mod tests {
398 use std::sync::Arc;
399
400 use arrow::{array::AsArray, datatypes::UInt64Type};
401 use arrow_array::{types::Int32Type, RecordBatch, UInt32Array};
402 use arrow_schema::ArrowError;
403 use futures::{stream::BoxStream, FutureExt, StreamExt, TryStreamExt};
404 use lance_core::{
405 utils::{address::RowAddress, deletion::DeletionVector},
406 ROW_ID,
407 };
408 use lance_datagen::{BatchCount, RowCount};
409 use lance_io::{stream::arrow_stream_to_lance_stream, ReadBatchParams};
410 use roaring::RoaringBitmap;
411
412 use crate::utils::stream::ReadBatchTask;
413
414 use super::RowIdAndDeletesConfig;
415
416 fn batch_task_stream(
417 datagen_stream: BoxStream<'static, std::result::Result<RecordBatch, ArrowError>>,
418 ) -> super::ReadBatchTaskStream {
419 arrow_stream_to_lance_stream(datagen_stream)
420 .map(|batch| ReadBatchTask {
421 num_rows: batch.as_ref().unwrap().num_rows() as u32,
422 task: std::future::ready(batch).boxed(),
423 })
424 .boxed()
425 }
426
427 #[tokio::test]
428 async fn test_basic_zip() {
429 let left = batch_task_stream(
430 lance_datagen::gen_batch()
431 .col("x", lance_datagen::array::step::<Int32Type>())
432 .into_reader_stream(RowCount::from(100), BatchCount::from(10))
433 .0,
434 );
435 let right = batch_task_stream(
436 lance_datagen::gen_batch()
437 .col("y", lance_datagen::array::step::<Int32Type>())
438 .into_reader_stream(RowCount::from(100), BatchCount::from(10))
439 .0,
440 );
441
442 let merged = super::merge_streams(vec![left, right])
443 .map(|batch_task| batch_task.task)
444 .buffered(1)
445 .try_collect::<Vec<_>>()
446 .await
447 .unwrap();
448
449 let expected = lance_datagen::gen_batch()
450 .col("x", lance_datagen::array::step::<Int32Type>())
451 .col("y", lance_datagen::array::step::<Int32Type>())
452 .into_reader_rows(RowCount::from(100), BatchCount::from(10))
453 .collect::<Result<Vec<_>, ArrowError>>()
454 .unwrap();
455 assert_eq!(merged, expected);
456 }
457
    /// Drive a 10-batch x 10-row stream through `wrap_with_row_id_and_delete`
    /// with the given read params and verify the generated row ids match
    /// `expected` (fragment-relative row offsets).
    async fn check_row_id(params: ReadBatchParams, expected: impl IntoIterator<Item = u32>) {
        let expected = Vec::from_iter(expected);

        // Row ids should be unaffected by the presence of data columns; the
        // fragment id only shifts the address space.
        for has_columns in [false, true] {
            for fragment_id in [0, 10] {
                let mut datagen = lance_datagen::gen_batch();
                if has_columns {
                    datagen = datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                }
                let data = batch_task_stream(
                    datagen
                        .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                        .0,
                );

                let config = RowIdAndDeletesConfig {
                    params: params.clone(),
                    with_row_id: true,
                    with_row_addr: false,
                    with_row_last_updated_at_version: false,
                    with_row_created_at_version: false,
                    deletion_vector: None,
                    row_id_sequence: None,
                    last_updated_at_sequence: None,
                    created_at_sequence: None,
                    make_deletions_null: false,
                    total_num_rows: 100,
                };
                let stream = super::wrap_with_row_id_and_delete(data, fragment_id, config);
                let batches = stream.buffered(1).try_collect::<Vec<_>>().await.unwrap();

                // Each batch holds 10 rows: compare against the matching
                // slice of `expected`, combined with the fragment id into
                // full row addresses.
                let mut offset = 0;
                let expected = expected.clone();
                for batch in batches {
                    let actual_row_ids =
                        batch[ROW_ID].as_primitive::<UInt64Type>().values().to_vec();
                    let expected_row_ids = expected[offset..offset + 10]
                        .iter()
                        .map(|row_offset| {
                            RowAddress::new_from_parts(fragment_id, *row_offset).into()
                        })
                        .collect::<Vec<u64>>();
                    assert_eq!(actual_row_ids, expected_row_ids);
                    offset += batch.num_rows();
                }
            }
        }
    }
507
508 #[tokio::test]
509 async fn test_row_id() {
510 let some_indices = (0..100).rev().collect::<Vec<u32>>();
511 let some_indices_arr = UInt32Array::from(some_indices.clone());
512 check_row_id(ReadBatchParams::RangeFull, 0..100).await;
513 check_row_id(ReadBatchParams::Indices(some_indices_arr), some_indices).await;
514 check_row_id(ReadBatchParams::Range(1000..1100), 1000..1100).await;
515 check_row_id(
516 ReadBatchParams::RangeFrom(std::ops::RangeFrom { start: 1000 }),
517 1000..1100,
518 )
519 .await;
520 check_row_id(
521 ReadBatchParams::RangeTo(std::ops::RangeTo { end: 1000 }),
522 0..100,
523 )
524 .await;
525 }
526
    #[tokio::test]
    async fn test_deletes() {
        // Deletion configurations under test: absent, an explicit no-op,
        // and the first 35 rows deleted via bitmap / set encodings.
        let no_deletes: Option<Arc<DeletionVector>> = None;
        let no_deletes_2 = Some(Arc::new(DeletionVector::NoDeletions));
        let delete_some_bitmap = Some(Arc::new(DeletionVector::Bitmap(RoaringBitmap::from_iter(
            0..35,
        ))));
        let delete_some_set = Some(Arc::new(DeletionVector::Set((0..35).collect())));

        for deletion_vector in [
            no_deletes,
            no_deletes_2,
            delete_some_bitmap,
            delete_some_set,
        ] {
            for has_columns in [false, true] {
                for with_row_id in [false, true] {
                    for make_deletions_null in [false, true] {
                        for frag_id in [0, 1] {
                            let has_deletions = if let Some(dv) = &deletion_vector {
                                !matches!(dv.as_ref(), DeletionVector::NoDeletions)
                            } else {
                                false
                            };
                            // A batch with no columns and nothing else to do
                            // is not a valid input.
                            if !has_columns && !has_deletions && !with_row_id {
                                continue;
                            }
                            // Deletions can only be expressed as nulls when a
                            // row-id column exists to hold them.
                            if make_deletions_null && !with_row_id {
                                continue;
                            }

                            let mut datagen = lance_datagen::gen_batch();
                            if has_columns {
                                datagen =
                                    datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                            }
                            // 10 batches of 10 rows = 100 rows total.
                            let data = batch_task_stream(
                                datagen
                                    .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                                    .0,
                            );

                            let config = RowIdAndDeletesConfig {
                                params: ReadBatchParams::RangeFull,
                                with_row_id,
                                with_row_addr: false,
                                with_row_last_updated_at_version: false,
                                with_row_created_at_version: false,
                                deletion_vector: deletion_vector.clone(),
                                row_id_sequence: None,
                                last_updated_at_sequence: None,
                                created_at_sequence: None,
                                make_deletions_null,
                                total_num_rows: 100,
                            };
                            let stream = super::wrap_with_row_id_and_delete(data, frag_id, config);
                            // Drop fully-deleted (empty) batches so the
                            // assertions below can index batches directly.
                            let batches = stream
                                .buffered(1)
                                .filter_map(|batch| {
                                    std::future::ready(
                                        batch
                                            .map(|batch| {
                                                if batch.num_rows() == 0 {
                                                    None
                                                } else {
                                                    Some(batch)
                                                }
                                            })
                                            .transpose(),
                                    )
                                })
                                .try_collect::<Vec<_>>()
                                .await
                                .unwrap();

                            // Deleted rows either disappear (filtered) or
                            // survive as nulls; the two must add up to the
                            // size of the deletion vector.
                            let total_num_rows =
                                batches.iter().map(|b| b.num_rows()).sum::<usize>();
                            let total_num_nulls = if make_deletions_null {
                                batches
                                    .iter()
                                    .map(|b| b[ROW_ID].null_count())
                                    .sum::<usize>()
                            } else {
                                0
                            };
                            let total_actually_deleted = total_num_nulls + (100 - total_num_rows);

                            let expected_deletions = match &deletion_vector {
                                None => 0,
                                Some(deletion_vector) => match deletion_vector.as_ref() {
                                    DeletionVector::NoDeletions => 0,
                                    DeletionVector::Bitmap(b) => b.len() as usize,
                                    DeletionVector::Set(s) => s.len(),
                                },
                            };
                            assert_eq!(total_actually_deleted, expected_deletions);
                            if expected_deletions > 0 && with_row_id {
                                if make_deletions_null {
                                    // Rows 0..35 are deleted.  With nulls
                                    // kept, no batch is empty, so batch 3
                                    // covers rows 30..40: its first row id is
                                    // address 30 and it holds the last 5
                                    // nulls (rows 30..35).
                                    assert_eq!(
                                        batches[3][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 30))
                                    );
                                    assert_eq!(batches[3][ROW_ID].null_count(), 5);
                                } else {
                                    // With filtering, batches 0-2 vanish and
                                    // the first surviving row is row 35.
                                    assert_eq!(
                                        batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 35))
                                    );
                                }
                            }
                            if !with_row_id {
                                assert!(batches[0].column_by_name(ROW_ID).is_none());
                            }
                        }
                    }
                }
            }
        }
    }
654}