lance_table/utils/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_array::{make_array, BooleanArray, RecordBatch, RecordBatchOptions, UInt64Array};
7use arrow_buffer::NullBuffer;
8use futures::{
9    future::BoxFuture,
10    stream::{BoxStream, FuturesOrdered},
11    FutureExt, Stream, StreamExt,
12};
13use lance_arrow::RecordBatchExt;
14use lance_core::{
15    utils::{address::RowAddress, deletion::DeletionVector},
16    Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD,
17};
18use lance_io::ReadBatchParams;
19use tracing::{instrument, Instrument};
20
21use crate::rowids::RowIdSequence;
22
/// A boxed, `'static` future that resolves to a single record batch (or an error).
pub type ReadBatchFut = BoxFuture<'static, Result<RecordBatch>>;
/// A task, emitted by a file reader, that will produce a batch (of the
/// given size)
pub struct ReadBatchTask {
    /// The future that produces the batch.
    pub task: ReadBatchFut,
    /// The number of rows the batch will contain. This is known before `task`
    /// resolves, which lets consumers compute row offsets up front.
    pub num_rows: u32,
}
/// A stream of batch tasks, each carrying its row count ahead of time.
pub type ReadBatchTaskStream = BoxStream<'static, ReadBatchTask>;
/// A stream of batch futures without an up-front row count (used when, e.g.,
/// deletions make the number of output rows unknowable in advance).
pub type ReadBatchFutStream = BoxStream<'static, ReadBatchFut>;
32
/// State behind [`merge_streams`]: polls the input streams in round-robin order,
/// collecting one task from each before emitting a single merged task.
struct MergeStream {
    /// The input streams; their batches' columns are merged in this order.
    streams: Vec<ReadBatchTaskStream>,
    /// Futures gathered so far for the batch currently being assembled, in stream order.
    next_batch: FuturesOrdered<ReadBatchFut>,
    /// Row count of the batch being assembled (taken from the first stream's task).
    next_num_rows: u32,
    /// Index of the input stream to poll next.
    index: usize,
}
39
40impl MergeStream {
41    fn emit(&mut self) -> ReadBatchTask {
42        let mut iter = std::mem::take(&mut self.next_batch);
43        let task = async move {
44            let mut batch = iter.next().await.unwrap()?;
45            while let Some(next) = iter.next().await {
46                let next = next?;
47                batch = batch.merge(&next)?;
48            }
49            Ok(batch)
50        }
51        .boxed();
52        let num_rows = self.next_num_rows;
53        self.next_num_rows = 0;
54        ReadBatchTask { task, num_rows }
55    }
56}
57
58impl Stream for MergeStream {
59    type Item = ReadBatchTask;
60
61    fn poll_next(
62        mut self: std::pin::Pin<&mut Self>,
63        cx: &mut std::task::Context<'_>,
64    ) -> std::task::Poll<Option<Self::Item>> {
65        loop {
66            let index = self.index;
67            match self.streams[index].poll_next_unpin(cx) {
68                std::task::Poll::Ready(Some(batch_task)) => {
69                    if self.index == 0 {
70                        self.next_num_rows = batch_task.num_rows;
71                    } else {
72                        debug_assert_eq!(self.next_num_rows, batch_task.num_rows);
73                    }
74                    self.next_batch.push_back(batch_task.task);
75                    self.index += 1;
76                    if self.index == self.streams.len() {
77                        self.index = 0;
78                        let next_batch = self.emit();
79                        return std::task::Poll::Ready(Some(next_batch));
80                    }
81                }
82                std::task::Poll::Ready(None) => {
83                    return std::task::Poll::Ready(None);
84                }
85                std::task::Poll::Pending => {
86                    return std::task::Poll::Pending;
87                }
88            }
89        }
90    }
91}
92
93/// Given multiple streams of batch tasks, merge them into a single stream
94///
95/// This pulls one batch from each stream and then combines the columns from
96/// all of the batches into a single batch.  The order of the batches in the
97/// streams is maintained and the merged batch columns will be in order from
98/// first to last stream.
99///
100/// This stream ends as soon as any of the input streams ends (we do not
101/// verify that the other input streams are finished as well)
102///
103/// This will panic if any of the input streams return a batch with a different
104/// number of rows than the first stream.
105pub fn merge_streams(streams: Vec<ReadBatchTaskStream>) -> ReadBatchTaskStream {
106    MergeStream {
107        streams,
108        next_batch: FuturesOrdered::new(),
109        next_num_rows: 0,
110        index: 0,
111    }
112    .boxed()
113}
114
115/// Apply a mask to the batch, where rows are "deleted" by the _rowid column null.
116///
117/// This is used partly as a performance optimization (cheaper to null than to filter)
118/// but also because there are cases where we want to load the physical rows.  For example,
119/// we may be replacing a column based on some UDF and we want to provide a value for the
120/// deleted rows to ensure the fragments are aligned.
121fn apply_deletions_as_nulls(batch: RecordBatch, mask: &BooleanArray) -> Result<RecordBatch> {
122    // Transform mask into null buffer. Null means deleted, though note that
123    // null buffers are actually validity buffers, so True means not null
124    // and thus not deleted.
125    let mask_buffer = NullBuffer::new(mask.values().clone());
126
127    if mask_buffer.null_count() == 0 {
128        // No rows are deleted
129        return Ok(batch);
130    }
131
132    // For each column convert to data
133    let new_columns = batch
134        .schema()
135        .fields()
136        .iter()
137        .zip(batch.columns())
138        .map(|(field, col)| {
139            if field.name() == ROW_ID || field.name() == ROW_ADDR {
140                let col_data = col.to_data();
141                // If it already has a validity bitmap, then AND it with the mask.
142                // Otherwise, use the boolean buffer as the mask.
143                let null_buffer = NullBuffer::union(col_data.nulls(), Some(&mask_buffer));
144
145                Ok(col_data
146                    .into_builder()
147                    .null_bit_buffer(null_buffer.map(|b| b.buffer().clone()))
148                    .build()
149                    .map(make_array)?)
150            } else {
151                Ok(col.clone())
152            }
153        })
154        .collect::<Result<Vec<_>>>()?;
155
156    Ok(RecordBatch::try_new_with_options(
157        batch.schema(),
158        new_columns,
159        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
160    )?)
161}
162
/// Configuration needed to apply row ids and deletions to a batch
#[derive(Debug)]
pub struct RowIdAndDeletesConfig {
    /// The rows that were requested, expressed as row offsets within the fragment
    /// (these offsets are combined with the fragment id to form row addresses)
    pub params: ReadBatchParams,
    /// Whether to include the row id column in the final batch
    pub with_row_id: bool,
    /// Whether to include the row address column in the final batch
    pub with_row_addr: bool,
    /// An optional deletion vector to apply to the batch
    /// (`Some(DeletionVector::NoDeletions)` is treated the same as `None`)
    pub deletion_vector: Option<Arc<DeletionVector>>,
    /// An optional row id sequence to use for the row id column.
    /// When absent, row ids are taken to be the row addresses.
    pub row_id_sequence: Option<Arc<RowIdSequence>>,
    /// Whether to make deleted rows null (in the row id / row address columns)
    /// instead of filtering them out
    pub make_deletions_null: bool,
    /// The total number of rows that will be loaded
    ///
    /// This is needed to convert ReadBatchParams::RangeTo into a valid range
    pub total_num_rows: u32,
}
183
184#[instrument(level = "debug", skip_all)]
185pub fn apply_row_id_and_deletes(
186    batch: RecordBatch,
187    batch_offset: u32,
188    fragment_id: u32,
189    config: &RowIdAndDeletesConfig,
190) -> Result<RecordBatch> {
191    let mut deletion_vector = config.deletion_vector.as_ref();
192    // Convert Some(NoDeletions) into None to simplify logic below
193    if let Some(deletion_vector_inner) = deletion_vector {
194        if matches!(deletion_vector_inner.as_ref(), DeletionVector::NoDeletions) {
195            deletion_vector = None;
196        }
197    }
198    let has_deletions = deletion_vector.is_some();
199    debug_assert!(
200        batch.num_columns() > 0 || config.with_row_id || config.with_row_addr || has_deletions
201    );
202
203    // If row id sequence is None, then row id IS row address.
204    let should_fetch_row_addr = config.with_row_addr
205        || (config.with_row_id && config.row_id_sequence.is_none())
206        || has_deletions;
207
208    let num_rows = batch.num_rows() as u32;
209
210    let row_addrs =
211        if should_fetch_row_addr {
212            let _rowaddrs = tracing::span!(tracing::Level::DEBUG, "fetch_row_addrs").entered();
213            let mut row_addrs = Vec::with_capacity(num_rows as usize);
214            for offset_range in config
215                .params
216                .slice(batch_offset as usize, num_rows as usize)
217                .unwrap()
218                .iter_offset_ranges()?
219            {
220                row_addrs.extend(offset_range.map(|row_offset| {
221                    u64::from(RowAddress::new_from_parts(fragment_id, row_offset))
222                }));
223            }
224
225            Some(Arc::new(UInt64Array::from(row_addrs)))
226        } else {
227            None
228        };
229
230    let row_ids = if config.with_row_id {
231        let _rowids = tracing::span!(tracing::Level::DEBUG, "fetch_row_ids").entered();
232        if let Some(row_id_sequence) = &config.row_id_sequence {
233            let selection = config
234                .params
235                .slice(batch_offset as usize, num_rows as usize)
236                .unwrap()
237                .to_ranges()
238                .unwrap();
239            let row_ids = row_id_sequence
240                .select(
241                    selection
242                        .iter()
243                        .flat_map(|r| r.start as usize..r.end as usize),
244                )
245                .collect::<UInt64Array>();
246            Some(Arc::new(row_ids))
247        } else {
248            // If we don't have a row id sequence, can assume the row ids are
249            // the same as the row addresses.
250            row_addrs.clone()
251        }
252    } else {
253        None
254    };
255
256    let span = tracing::span!(tracing::Level::DEBUG, "apply_deletions");
257    let _enter = span.enter();
258    let deletion_mask = deletion_vector.and_then(|v| {
259        let row_addrs: &[u64] = row_addrs.as_ref().unwrap().values();
260        v.build_predicate(row_addrs.iter())
261    });
262
263    let batch = if config.with_row_id {
264        let row_id_arr = row_ids.unwrap();
265        batch.try_with_column(ROW_ID_FIELD.clone(), row_id_arr)?
266    } else {
267        batch
268    };
269
270    let batch = if config.with_row_addr {
271        let row_addr_arr = row_addrs.unwrap();
272        batch.try_with_column(ROW_ADDR_FIELD.clone(), row_addr_arr)?
273    } else {
274        batch
275    };
276
277    match (deletion_mask, config.make_deletions_null) {
278        (None, _) => Ok(batch),
279        (Some(mask), false) => Ok(arrow::compute::filter_record_batch(&batch, &mask)?),
280        (Some(mask), true) => Ok(apply_deletions_as_nulls(batch, &mask)?),
281    }
282}
283
284/// Given a stream of batch tasks this function will add a row ids column (if requested)
285/// and also apply a deletions vector to the batch.
286///
287/// This converts from BatchTaskStream to BatchFutStream because, if we are applying a
288/// deletion vector, it is impossible to know how many output rows we will have.
289pub fn wrap_with_row_id_and_delete(
290    stream: ReadBatchTaskStream,
291    fragment_id: u32,
292    config: RowIdAndDeletesConfig,
293) -> ReadBatchFutStream {
294    let config = Arc::new(config);
295    let mut offset = 0;
296    stream
297        .map(move |batch_task| {
298            let config = config.clone();
299            let this_offset = offset;
300            let num_rows = batch_task.num_rows;
301            offset += num_rows;
302            let task = batch_task.task;
303            tokio::spawn(
304                async move {
305                    let batch = task.await?;
306                    apply_row_id_and_deletes(batch, this_offset, fragment_id, config.as_ref())
307                }
308                .in_current_span(),
309            )
310            .map(|join_wrapper| join_wrapper.unwrap())
311            .boxed()
312        })
313        .boxed()
314}
315
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::AsArray, datatypes::UInt64Type};
    use arrow_array::{types::Int32Type, RecordBatch, UInt32Array};
    use arrow_schema::ArrowError;
    use futures::{stream::BoxStream, FutureExt, StreamExt, TryStreamExt};
    use lance_core::{
        utils::{address::RowAddress, deletion::DeletionVector},
        ROW_ID,
    };
    use lance_datagen::{BatchCount, RowCount};
    use lance_io::{stream::arrow_stream_to_lance_stream, ReadBatchParams};
    use roaring::RoaringBitmap;

    use crate::utils::stream::ReadBatchTask;

    use super::RowIdAndDeletesConfig;

    /// Wrap an arrow record-batch stream as a `ReadBatchTaskStream` whose tasks
    /// are already-resolved futures.
    ///
    /// Panics if the input stream yields an error (the row count is read via `unwrap`).
    fn batch_task_stream(
        datagen_stream: BoxStream<'static, std::result::Result<RecordBatch, ArrowError>>,
    ) -> super::ReadBatchTaskStream {
        arrow_stream_to_lance_stream(datagen_stream)
            .map(|batch| ReadBatchTask {
                num_rows: batch.as_ref().unwrap().num_rows() as u32,
                task: std::future::ready(batch).boxed(),
            })
            .boxed()
    }

    /// Merging an `x` stream with a `y` stream should give the same batches as
    /// generating both columns together.
    #[tokio::test]
    async fn test_basic_zip() {
        let left = batch_task_stream(
            lance_datagen::gen_batch()
                .col("x", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );
        let right = batch_task_stream(
            lance_datagen::gen_batch()
                .col("y", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );

        let merged = super::merge_streams(vec![left, right])
            .map(|batch_task| batch_task.task)
            .buffered(1)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let expected = lance_datagen::gen_batch()
            .col("x", lance_datagen::array::step::<Int32Type>())
            .col("y", lance_datagen::array::step::<Int32Type>())
            .into_reader_rows(RowCount::from(100), BatchCount::from(10))
            .collect::<Result<Vec<_>, ArrowError>>()
            .unwrap();
        assert_eq!(merged, expected);
    }

    /// Run `wrap_with_row_id_and_delete` over 10 batches of 10 rows selected by
    /// `params`. Check that each batch's row ids equal
    /// `RowAddress::new_from_parts(fragment_id, offset)` for the corresponding
    /// slice of `expected` offsets. Without a row id sequence, row ids are the
    /// row addresses.
    async fn check_row_id(params: ReadBatchParams, expected: impl IntoIterator<Item = u32>) {
        let expected = Vec::from_iter(expected);

        // Cover batches with and without a data column, and more than one fragment id.
        for has_columns in [false, true] {
            for fragment_id in [0, 10] {
                // 100 rows across 10 batches of 10 rows
                let mut datagen = lance_datagen::gen_batch();
                if has_columns {
                    datagen = datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                }
                let data = batch_task_stream(
                    datagen
                        .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                        .0,
                );

                let config = RowIdAndDeletesConfig {
                    params: params.clone(),
                    with_row_id: true,
                    with_row_addr: false,
                    deletion_vector: None,
                    row_id_sequence: None,
                    make_deletions_null: false,
                    total_num_rows: 100,
                };
                let stream = super::wrap_with_row_id_and_delete(data, fragment_id, config);
                let batches = stream.buffered(1).try_collect::<Vec<_>>().await.unwrap();

                let mut offset = 0;
                let expected = expected.clone();
                for batch in batches {
                    let actual_row_ids =
                        batch[ROW_ID].as_primitive::<UInt64Type>().values().to_vec();
                    // Each batch has 10 rows, so compare against the next 10 expected offsets.
                    let expected_row_ids = expected[offset..offset + 10]
                        .iter()
                        .map(|row_offset| {
                            RowAddress::new_from_parts(fragment_id, *row_offset).into()
                        })
                        .collect::<Vec<u64>>();
                    assert_eq!(actual_row_ids, expected_row_ids);
                    offset += batch.num_rows();
                }
            }
        }
    }

    /// Row ids follow the requested offsets for each `ReadBatchParams` variant.
    #[tokio::test]
    async fn test_row_id() {
        let some_indices = (0..100).rev().collect::<Vec<u32>>();
        let some_indices_arr = UInt32Array::from(some_indices.clone());
        check_row_id(ReadBatchParams::RangeFull, 0..100).await;
        check_row_id(ReadBatchParams::Indices(some_indices_arr), some_indices).await;
        check_row_id(ReadBatchParams::Range(1000..1100), 1000..1100).await;
        check_row_id(
            ReadBatchParams::RangeFrom(std::ops::RangeFrom { start: 1000 }),
            1000..1100,
        )
        .await;
        // Only 100 rows are produced, so `..1000` yields offsets 0..100.
        check_row_id(
            ReadBatchParams::RangeTo(std::ops::RangeTo { end: 1000 }),
            0..100,
        )
        .await;
    }

    /// Deletions of rows 0..35 are applied correctly (filtered or nulled) for
    /// every supported combination of deletion-vector representation, data
    /// columns, row id inclusion, null mode and fragment id.
    #[tokio::test]
    async fn test_deletes() {
        let no_deletes: Option<Arc<DeletionVector>> = None;
        let no_deletes_2 = Some(Arc::new(DeletionVector::NoDeletions));
        let delete_some_bitmap = Some(Arc::new(DeletionVector::Bitmap(RoaringBitmap::from_iter(
            0..35,
        ))));
        let delete_some_set = Some(Arc::new(DeletionVector::Set((0..35).collect())));

        for deletion_vector in [
            no_deletes,
            no_deletes_2,
            delete_some_bitmap,
            delete_some_set,
        ] {
            for has_columns in [false, true] {
                for with_row_id in [false, true] {
                    for make_deletions_null in [false, true] {
                        for frag_id in [0, 1] {
                            let has_deletions = if let Some(dv) = &deletion_vector {
                                !matches!(dv.as_ref(), DeletionVector::NoDeletions)
                            } else {
                                false
                            };
                            if !has_columns && !has_deletions && !with_row_id {
                                // This is an invalid case and should be prevented upstream,
                                // no meaningful work is being done!
                                continue;
                            }
                            if make_deletions_null && !with_row_id {
                                // This is an invalid case and should be prevented upstream
                                // we cannot make the row_id column null if it isn't present
                                continue;
                            }

                            let mut datagen = lance_datagen::gen_batch();
                            if has_columns {
                                datagen =
                                    datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                            }
                            // 100 rows across 10 batches of 10 rows
                            let data = batch_task_stream(
                                datagen
                                    .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                                    .0,
                            );

                            let config = RowIdAndDeletesConfig {
                                params: ReadBatchParams::RangeFull,
                                with_row_id,
                                with_row_addr: false,
                                deletion_vector: deletion_vector.clone(),
                                row_id_sequence: None,
                                make_deletions_null,
                                total_num_rows: 100,
                            };
                            let stream = super::wrap_with_row_id_and_delete(data, frag_id, config);
                            // Drop empty batches (e.g. fully filtered-out batches) so the
                            // indices used below refer to non-empty batches only.
                            let batches = stream
                                .buffered(1)
                                .filter_map(|batch| {
                                    std::future::ready(
                                        batch
                                            .map(|batch| {
                                                if batch.num_rows() == 0 {
                                                    None
                                                } else {
                                                    Some(batch)
                                                }
                                            })
                                            .transpose(),
                                    )
                                })
                                .try_collect::<Vec<_>>()
                                .await
                                .unwrap();

                            // A row counts as deleted if it was filtered out or its row id was nulled.
                            let total_num_rows =
                                batches.iter().map(|b| b.num_rows()).sum::<usize>();
                            let total_num_nulls = if make_deletions_null {
                                batches
                                    .iter()
                                    .map(|b| b[ROW_ID].null_count())
                                    .sum::<usize>()
                            } else {
                                0
                            };
                            let total_actually_deleted = total_num_nulls + (100 - total_num_rows);

                            let expected_deletions = match &deletion_vector {
                                None => 0,
                                Some(deletion_vector) => match deletion_vector.as_ref() {
                                    DeletionVector::NoDeletions => 0,
                                    DeletionVector::Bitmap(b) => b.len() as usize,
                                    DeletionVector::Set(s) => s.len(),
                                },
                            };
                            assert_eq!(total_actually_deleted, expected_deletions);
                            if expected_deletions > 0 && with_row_id {
                                if make_deletions_null {
                                    // If we make deletions null we get 3 batches of all-null and then
                                    // a batch of half-null
                                    assert_eq!(
                                        batches[3][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 30))
                                    );
                                    assert_eq!(batches[3][ROW_ID].null_count(), 5);
                                } else {
                                    // If we materialize deletions the first row will be 35
                                    assert_eq!(
                                        batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 35))
                                    );
                                }
                            }
                            if !with_row_id {
                                assert!(batches[0].column_by_name(ROW_ID).is_none());
                            }
                        }
                    }
                }
            }
        }
    }
}