lance_table/utils/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_array::{make_array, BooleanArray, RecordBatch, RecordBatchOptions, UInt64Array};
7use arrow_buffer::NullBuffer;
8use futures::{
9    future::BoxFuture,
10    stream::{BoxStream, FuturesOrdered},
11    FutureExt, Stream, StreamExt,
12};
13use lance_arrow::RecordBatchExt;
14use lance_core::{
15    utils::{address::RowAddress, deletion::DeletionVector},
16    Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_CREATED_AT_VERSION_FIELD, ROW_ID, ROW_ID_FIELD,
17    ROW_LAST_UPDATED_AT_VERSION_FIELD,
18};
19use lance_io::ReadBatchParams;
20use tracing::instrument;
21
22use crate::rowids::RowIdSequence;
23
/// A boxed, `'static` future that resolves to one [`RecordBatch`] or an error
pub type ReadBatchFut = BoxFuture<'static, Result<RecordBatch>>;
/// A task, emitted by a file reader, that will produce a batch (of the
/// given size)
pub struct ReadBatchTask {
    /// The future that yields the batch once awaited
    pub task: ReadBatchFut,
    /// The number of rows the batch produced by `task` is expected to have
    pub num_rows: u32,
}
/// A boxed stream of [`ReadBatchTask`]s (batch futures paired with their row counts)
pub type ReadBatchTaskStream = BoxStream<'static, ReadBatchTask>;
/// A boxed stream of bare [`ReadBatchFut`]s (no row count attached)
pub type ReadBatchFutStream = BoxStream<'static, ReadBatchFut>;
33
34struct MergeStream {
35    streams: Vec<ReadBatchTaskStream>,
36    next_batch: FuturesOrdered<ReadBatchFut>,
37    next_num_rows: u32,
38    index: usize,
39}
40
41impl MergeStream {
42    fn emit(&mut self) -> ReadBatchTask {
43        let mut iter = std::mem::take(&mut self.next_batch);
44        let task = async move {
45            let mut batch = iter.next().await.unwrap()?;
46            while let Some(next) = iter.next().await {
47                let next = next?;
48                batch = batch.merge(&next)?;
49            }
50            Ok(batch)
51        }
52        .boxed();
53        let num_rows = self.next_num_rows;
54        self.next_num_rows = 0;
55        ReadBatchTask { task, num_rows }
56    }
57}
58
59impl Stream for MergeStream {
60    type Item = ReadBatchTask;
61
62    fn poll_next(
63        mut self: std::pin::Pin<&mut Self>,
64        cx: &mut std::task::Context<'_>,
65    ) -> std::task::Poll<Option<Self::Item>> {
66        loop {
67            let index = self.index;
68            match self.streams[index].poll_next_unpin(cx) {
69                std::task::Poll::Ready(Some(batch_task)) => {
70                    if self.index == 0 {
71                        self.next_num_rows = batch_task.num_rows;
72                    } else {
73                        debug_assert_eq!(self.next_num_rows, batch_task.num_rows);
74                    }
75                    self.next_batch.push_back(batch_task.task);
76                    self.index += 1;
77                    if self.index == self.streams.len() {
78                        self.index = 0;
79                        let next_batch = self.emit();
80                        return std::task::Poll::Ready(Some(next_batch));
81                    }
82                }
83                std::task::Poll::Ready(None) => {
84                    return std::task::Poll::Ready(None);
85                }
86                std::task::Poll::Pending => {
87                    return std::task::Poll::Pending;
88                }
89            }
90        }
91    }
92}
93
94/// Given multiple streams of batch tasks, merge them into a single stream
95///
96/// This pulls one batch from each stream and then combines the columns from
97/// all of the batches into a single batch.  The order of the batches in the
98/// streams is maintained and the merged batch columns will be in order from
99/// first to last stream.
100///
101/// This stream ends as soon as any of the input streams ends (we do not
102/// verify that the other input streams are finished as well)
103///
104/// This will panic if any of the input streams return a batch with a different
105/// number of rows than the first stream.
106pub fn merge_streams(streams: Vec<ReadBatchTaskStream>) -> ReadBatchTaskStream {
107    MergeStream {
108        streams,
109        next_batch: FuturesOrdered::new(),
110        next_num_rows: 0,
111        index: 0,
112    }
113    .boxed()
114}
115
116/// Apply a mask to the batch, where rows are "deleted" by the _rowid column null.
117///
118/// This is used partly as a performance optimization (cheaper to null than to filter)
119/// but also because there are cases where we want to load the physical rows.  For example,
120/// we may be replacing a column based on some UDF and we want to provide a value for the
121/// deleted rows to ensure the fragments are aligned.
122fn apply_deletions_as_nulls(batch: RecordBatch, mask: &BooleanArray) -> Result<RecordBatch> {
123    // Transform mask into null buffer. Null means deleted, though note that
124    // null buffers are actually validity buffers, so True means not null
125    // and thus not deleted.
126    let mask_buffer = NullBuffer::new(mask.values().clone());
127
128    if mask_buffer.null_count() == 0 {
129        // No rows are deleted
130        return Ok(batch);
131    }
132
133    // For each column convert to data
134    let new_columns = batch
135        .schema()
136        .fields()
137        .iter()
138        .zip(batch.columns())
139        .map(|(field, col)| {
140            if field.name() == ROW_ID || field.name() == ROW_ADDR {
141                let col_data = col.to_data();
142                // If it already has a validity bitmap, then AND it with the mask.
143                // Otherwise, use the boolean buffer as the mask.
144                let null_buffer = NullBuffer::union(col_data.nulls(), Some(&mask_buffer));
145
146                Ok(col_data
147                    .into_builder()
148                    .null_bit_buffer(null_buffer.map(|b| b.buffer().clone()))
149                    .build()
150                    .map(make_array)?)
151            } else {
152                Ok(col.clone())
153            }
154        })
155        .collect::<Result<Vec<_>>>()?;
156
157    Ok(RecordBatch::try_new_with_options(
158        batch.schema(),
159        new_columns,
160        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
161    )?)
162}
163
164/// Configuration needed to apply row ids and deletions to a batch
165#[derive(Debug)]
166pub struct RowIdAndDeletesConfig {
167    /// The row ids that were requested
168    pub params: ReadBatchParams,
169    /// Whether to include the row id column in the final batch
170    pub with_row_id: bool,
171    /// Whether to include the row address column in the final batch
172    pub with_row_addr: bool,
173    /// Whether to include the last updated at version column in the final batch
174    pub with_row_last_updated_at_version: bool,
175    /// Whether to include the created at version column in the final batch
176    pub with_row_created_at_version: bool,
177    /// An optional deletion vector to apply to the batch
178    pub deletion_vector: Option<Arc<DeletionVector>>,
179    /// An optional row id sequence to use for the row id column.
180    pub row_id_sequence: Option<Arc<RowIdSequence>>,
181    /// The last_updated_at version sequence
182    pub last_updated_at_sequence: Option<Arc<crate::rowids::version::RowDatasetVersionSequence>>,
183    /// The created_at version sequence
184    pub created_at_sequence: Option<Arc<crate::rowids::version::RowDatasetVersionSequence>>,
185    /// Whether to make deleted rows null instead of filtering them out
186    pub make_deletions_null: bool,
187    /// The total number of rows that will be loaded
188    ///
189    /// This is needed to convert ReadbatchParams::RangeTo into a valid range
190    pub total_num_rows: u32,
191}
192
193impl RowIdAndDeletesConfig {
194    fn has_system_cols(&self) -> bool {
195        self.with_row_id
196            || self.with_row_addr
197            || self.with_row_last_updated_at_version
198            || self.with_row_created_at_version
199    }
200}
201
202#[instrument(level = "debug", skip_all)]
203pub fn apply_row_id_and_deletes(
204    batch: RecordBatch,
205    batch_offset: u32,
206    fragment_id: u32,
207    config: &RowIdAndDeletesConfig,
208) -> Result<RecordBatch> {
209    let mut deletion_vector = config.deletion_vector.as_ref();
210    // Convert Some(NoDeletions) into None to simplify logic below
211    if let Some(deletion_vector_inner) = deletion_vector {
212        if matches!(deletion_vector_inner.as_ref(), DeletionVector::NoDeletions) {
213            deletion_vector = None;
214        }
215    }
216    let has_deletions = deletion_vector.is_some();
217    debug_assert!(batch.num_columns() > 0 || config.has_system_cols() || has_deletions);
218
219    // If row id sequence is None, then row id IS row address.
220    let should_fetch_row_addr = config.with_row_addr
221        || (config.with_row_id && config.row_id_sequence.is_none())
222        || has_deletions;
223
224    let num_rows = batch.num_rows() as u32;
225
226    let row_addrs =
227        if should_fetch_row_addr {
228            let _rowaddrs = tracing::span!(tracing::Level::DEBUG, "fetch_row_addrs").entered();
229            let mut row_addrs = Vec::with_capacity(num_rows as usize);
230            for offset_range in config
231                .params
232                .slice(batch_offset as usize, num_rows as usize)
233                .unwrap()
234                .iter_offset_ranges()?
235            {
236                row_addrs.extend(offset_range.map(|row_offset| {
237                    u64::from(RowAddress::new_from_parts(fragment_id, row_offset))
238                }));
239            }
240
241            Some(Arc::new(UInt64Array::from(row_addrs)))
242        } else {
243            None
244        };
245
246    let row_ids = if config.with_row_id {
247        let _rowids = tracing::span!(tracing::Level::DEBUG, "fetch_row_ids").entered();
248        if let Some(row_id_sequence) = &config.row_id_sequence {
249            let selection = config
250                .params
251                .slice(batch_offset as usize, num_rows as usize)
252                .unwrap()
253                .to_ranges()
254                .unwrap();
255            let row_ids = row_id_sequence
256                .select(
257                    selection
258                        .iter()
259                        .flat_map(|r| r.start as usize..r.end as usize),
260                )
261                .collect::<UInt64Array>();
262            Some(Arc::new(row_ids))
263        } else {
264            // If we don't have a row id sequence, can assume the row ids are
265            // the same as the row addresses.
266            row_addrs.clone()
267        }
268    } else {
269        None
270    };
271
272    let span = tracing::span!(tracing::Level::DEBUG, "apply_deletions");
273    let _enter = span.enter();
274    let deletion_mask = deletion_vector.and_then(|v| {
275        let row_addrs: &[u64] = row_addrs.as_ref().unwrap().values();
276        v.build_predicate(row_addrs.iter())
277    });
278
279    let batch = if config.with_row_id {
280        let row_id_arr = row_ids.unwrap();
281        batch.try_with_column(ROW_ID_FIELD.clone(), row_id_arr)?
282    } else {
283        batch
284    };
285
286    let batch = if config.with_row_addr {
287        let row_addr_arr = row_addrs.unwrap();
288        batch.try_with_column(ROW_ADDR_FIELD.clone(), row_addr_arr)?
289    } else {
290        batch
291    };
292
293    // Add version columns if requested
294    let batch = if config.with_row_last_updated_at_version || config.with_row_created_at_version {
295        let mut batch = batch;
296
297        if config.with_row_last_updated_at_version {
298            let version_arr = if let Some(sequence) = &config.last_updated_at_sequence {
299                // Get the range of rows for this batch
300                let selection = config
301                    .params
302                    .slice(batch_offset as usize, num_rows as usize)
303                    .unwrap()
304                    .to_ranges()
305                    .unwrap();
306                // Extract version values for the selected ranges
307                let versions: Vec<u64> = selection
308                    .iter()
309                    .flat_map(|r| {
310                        sequence
311                            .versions()
312                            .skip(r.start as usize)
313                            .take((r.end - r.start) as usize)
314                    })
315                    .collect();
316                Arc::new(UInt64Array::from(versions))
317            } else {
318                // Default to version 1 if sequence not provided
319                Arc::new(UInt64Array::from(vec![1u64; num_rows as usize]))
320            };
321            batch =
322                batch.try_with_column(ROW_LAST_UPDATED_AT_VERSION_FIELD.clone(), version_arr)?;
323        }
324
325        if config.with_row_created_at_version {
326            let version_arr = if let Some(sequence) = &config.created_at_sequence {
327                // Get the range of rows for this batch
328                let selection = config
329                    .params
330                    .slice(batch_offset as usize, num_rows as usize)
331                    .unwrap()
332                    .to_ranges()
333                    .unwrap();
334                // Extract version values for the selected ranges
335                let versions: Vec<u64> = selection
336                    .iter()
337                    .flat_map(|r| {
338                        sequence
339                            .versions()
340                            .skip(r.start as usize)
341                            .take((r.end - r.start) as usize)
342                    })
343                    .collect();
344                Arc::new(UInt64Array::from(versions))
345            } else {
346                // Default to version 1 if sequence not provided
347                Arc::new(UInt64Array::from(vec![1u64; num_rows as usize]))
348            };
349            batch = batch.try_with_column(ROW_CREATED_AT_VERSION_FIELD.clone(), version_arr)?;
350        }
351
352        batch
353    } else {
354        batch
355    };
356
357    match (deletion_mask, config.make_deletions_null) {
358        (None, _) => Ok(batch),
359        (Some(mask), false) => Ok(arrow::compute::filter_record_batch(&batch, &mask)?),
360        (Some(mask), true) => Ok(apply_deletions_as_nulls(batch, &mask)?),
361    }
362}
363
364/// Given a stream of batch tasks this function will add a row ids column (if requested)
365/// and also apply a deletions vector to the batch.
366///
367/// This converts from BatchTaskStream to BatchFutStream because, if we are applying a
368/// deletion vector, it is impossible to know how many output rows we will have.
369pub fn wrap_with_row_id_and_delete(
370    stream: ReadBatchTaskStream,
371    fragment_id: u32,
372    config: RowIdAndDeletesConfig,
373) -> ReadBatchFutStream {
374    let config = Arc::new(config);
375    let mut offset = 0;
376    stream
377        .map(move |batch_task| {
378            let config = config.clone();
379            let this_offset = offset;
380            let num_rows = batch_task.num_rows;
381            offset += num_rows;
382            batch_task
383                .task
384                .map(move |batch| {
385                    apply_row_id_and_deletes(batch?, this_offset, fragment_id, config.as_ref())
386                })
387                .boxed()
388        })
389        .boxed()
390}
391
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::AsArray, datatypes::UInt64Type};
    use arrow_array::{types::Int32Type, RecordBatch, UInt32Array};
    use arrow_schema::ArrowError;
    use futures::{stream::BoxStream, FutureExt, StreamExt, TryStreamExt};
    use lance_core::{
        utils::{address::RowAddress, deletion::DeletionVector},
        ROW_ID,
    };
    use lance_datagen::{BatchCount, RowCount};
    use lance_io::{stream::arrow_stream_to_lance_stream, ReadBatchParams};
    use roaring::RoaringBitmap;

    use crate::utils::stream::ReadBatchTask;

    use super::RowIdAndDeletesConfig;

    /// Turns a stream of generated batches into a stream of already-resolved
    /// batch tasks.  Panics if the generator yields an error.
    fn batch_task_stream(
        datagen_stream: BoxStream<'static, std::result::Result<RecordBatch, ArrowError>>,
    ) -> super::ReadBatchTaskStream {
        let lance_stream = arrow_stream_to_lance_stream(datagen_stream);
        lance_stream
            .map(|maybe_batch| {
                let num_rows = maybe_batch.as_ref().unwrap().num_rows() as u32;
                let task = std::future::ready(maybe_batch).boxed();
                ReadBatchTask { task, num_rows }
            })
            .boxed()
    }

    /// 100 rows across 10 batches of 10 rows, optionally with a random
    /// Int32 column "x" (otherwise the batches have no columns).
    fn ten_batches_of_ten(has_columns: bool) -> super::ReadBatchTaskStream {
        let mut generator = lance_datagen::gen_batch();
        if has_columns {
            generator = generator.col("x", lance_datagen::array::rand::<Int32Type>());
        }
        batch_task_stream(
            generator
                .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                .0,
        )
    }

    /// Config shared by the tests: no row address or version columns, no
    /// sequences, 100 total rows.
    fn make_config(
        params: ReadBatchParams,
        with_row_id: bool,
        deletion_vector: Option<Arc<DeletionVector>>,
        make_deletions_null: bool,
    ) -> RowIdAndDeletesConfig {
        RowIdAndDeletesConfig {
            params,
            with_row_id,
            with_row_addr: false,
            with_row_last_updated_at_version: false,
            with_row_created_at_version: false,
            deletion_vector,
            row_id_sequence: None,
            last_updated_at_sequence: None,
            created_at_sequence: None,
            make_deletions_null,
            total_num_rows: 100,
        }
    }

    /// Number of rows the given (optional) deletion vector marks as deleted
    fn num_deleted(deletion_vector: &Option<Arc<DeletionVector>>) -> usize {
        match deletion_vector.as_deref() {
            None | Some(DeletionVector::NoDeletions) => 0,
            Some(DeletionVector::Bitmap(bitmap)) => bitmap.len() as usize,
            Some(DeletionVector::Set(set)) => set.len(),
        }
    }

    #[tokio::test]
    async fn test_basic_zip() {
        // Single-column stream of 10 batches x 100 rows with a step column
        let single_column_stream = |name: &'static str| {
            batch_task_stream(
                lance_datagen::gen_batch()
                    .col(name, lance_datagen::array::step::<Int32Type>())
                    .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                    .0,
            )
        };
        let left = single_column_stream("x");
        let right = single_column_stream("y");

        let merged_stream = super::merge_streams(vec![left, right]);
        let merged = merged_stream
            .map(|batch_task| batch_task.task)
            .buffered(1)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Merging must be equivalent to generating both columns together
        let expected = lance_datagen::gen_batch()
            .col("x", lance_datagen::array::step::<Int32Type>())
            .col("y", lance_datagen::array::step::<Int32Type>())
            .into_reader_rows(RowCount::from(100), BatchCount::from(10))
            .collect::<Result<Vec<_>, ArrowError>>()
            .unwrap();
        assert_eq!(merged, expected);
    }

    /// Reads 10 batches of 10 rows using `params` and checks that the row id
    /// column matches the `expected` row offsets (combined with the fragment id).
    async fn check_row_id(params: ReadBatchParams, expected: impl IntoIterator<Item = u32>) {
        let expected: Vec<u32> = expected.into_iter().collect();

        for has_columns in [false, true] {
            for fragment_id in [0, 10] {
                let data = ten_batches_of_ten(has_columns);
                let config = make_config(params.clone(), true, None, false);
                let batches = super::wrap_with_row_id_and_delete(data, fragment_id, config)
                    .buffered(1)
                    .try_collect::<Vec<_>>()
                    .await
                    .unwrap();

                let mut offset = 0;
                for batch in batches {
                    let actual_row_ids =
                        batch[ROW_ID].as_primitive::<UInt64Type>().values().to_vec();
                    let expected_row_ids: Vec<u64> = expected[offset..offset + 10]
                        .iter()
                        .map(|&row_offset| {
                            u64::from(RowAddress::new_from_parts(fragment_id, row_offset))
                        })
                        .collect();
                    assert_eq!(actual_row_ids, expected_row_ids);
                    offset += batch.num_rows();
                }
            }
        }
    }

    #[tokio::test]
    async fn test_row_id() {
        let some_indices = (0..100).rev().collect::<Vec<u32>>();
        let some_indices_arr = UInt32Array::from(some_indices.clone());

        let cases: Vec<(ReadBatchParams, Vec<u32>)> = vec![
            (ReadBatchParams::RangeFull, (0..100).collect()),
            (ReadBatchParams::Indices(some_indices_arr), some_indices),
            (ReadBatchParams::Range(1000..1100), (1000..1100).collect()),
            (ReadBatchParams::RangeFrom(1000..), (1000..1100).collect()),
            (ReadBatchParams::RangeTo(..1000), (0..100).collect()),
        ];
        for (params, expected) in cases {
            check_row_id(params, expected).await;
        }
    }

    /// Runs one combination of the `test_deletes` matrix.  Combinations that
    /// are invalid (and expected to be prevented upstream) are skipped.
    async fn check_deletes(
        deletion_vector: &Option<Arc<DeletionVector>>,
        has_columns: bool,
        with_row_id: bool,
        make_deletions_null: bool,
        frag_id: u32,
    ) {
        let has_deletions = !matches!(
            deletion_vector.as_deref(),
            None | Some(DeletionVector::NoDeletions)
        );
        if !(has_columns || has_deletions || with_row_id) {
            // This is an invalid case and should be prevented upstream,
            // no meaningful work is being done!
            return;
        }
        if make_deletions_null && !with_row_id {
            // This is an invalid case and should be prevented upstream
            // we cannot make the row_id column null if it isn't present
            return;
        }

        let data = ten_batches_of_ten(has_columns);
        let config = make_config(
            ReadBatchParams::RangeFull,
            with_row_id,
            deletion_vector.clone(),
            make_deletions_null,
        );
        // Collect the output, dropping batches that ended up empty
        let batches = super::wrap_with_row_id_and_delete(data, frag_id, config)
            .buffered(1)
            .filter_map(|batch| {
                std::future::ready(match batch {
                    Ok(batch) if batch.num_rows() == 0 => None,
                    other => Some(other),
                })
            })
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let total_num_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        let total_num_nulls: usize = if make_deletions_null {
            batches.iter().map(|b| b[ROW_ID].null_count()).sum()
        } else {
            0
        };
        // A row counts as deleted if it was nulled out or is missing entirely
        let total_actually_deleted = total_num_nulls + (100 - total_num_rows);

        let expected_deletions = num_deleted(deletion_vector);
        assert_eq!(total_actually_deleted, expected_deletions);

        if expected_deletions > 0 && with_row_id {
            if make_deletions_null {
                // If we make deletions null we get 3 batches of all-null and then
                // a batch of half-null
                let half_null = &batches[3][ROW_ID];
                assert_eq!(
                    half_null.as_primitive::<UInt64Type>().value(0),
                    u64::from(RowAddress::new_from_parts(frag_id, 30))
                );
                assert_eq!(half_null.null_count(), 5);
            } else {
                // If we materialize deletions the first row will be 35
                assert_eq!(
                    batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                    u64::from(RowAddress::new_from_parts(frag_id, 35))
                );
            }
        }
        if !with_row_id {
            assert!(batches[0].column_by_name(ROW_ID).is_none());
        }
    }

    #[tokio::test]
    async fn test_deletes() {
        let no_deletes: Option<Arc<DeletionVector>> = None;
        let no_deletes_2 = Some(Arc::new(DeletionVector::NoDeletions));
        let delete_some_bitmap = Some(Arc::new(DeletionVector::Bitmap(RoaringBitmap::from_iter(
            0..35,
        ))));
        let delete_some_set = Some(Arc::new(DeletionVector::Set((0..35).collect())));

        for deletion_vector in [
            no_deletes,
            no_deletes_2,
            delete_some_bitmap,
            delete_some_set,
        ] {
            for has_columns in [false, true] {
                for with_row_id in [false, true] {
                    for make_deletions_null in [false, true] {
                        for frag_id in [0, 1] {
                            check_deletes(
                                &deletion_vector,
                                has_columns,
                                with_row_id,
                                make_deletions_null,
                                frag_id,
                            )
                            .await;
                        }
                    }
                }
            }
        }
    }
}