lance_table/utils/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_array::{make_array, BooleanArray, RecordBatch, RecordBatchOptions, UInt64Array};
7use arrow_buffer::NullBuffer;
8use futures::{
9    future::BoxFuture,
10    stream::{BoxStream, FuturesOrdered},
11    FutureExt, Stream, StreamExt,
12};
13use lance_arrow::RecordBatchExt;
14use lance_core::{
15    utils::{address::RowAddress, deletion::DeletionVector},
16    Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD,
17};
18use lance_io::ReadBatchParams;
19
20use crate::rowids::RowIdSequence;
21
/// A boxed, `'static` future that resolves to a [`RecordBatch`] (or an error).
pub type ReadBatchFut = BoxFuture<'static, Result<RecordBatch>>;
/// A task, emitted by a file reader, that will produce a batch (of the
/// given size)
pub struct ReadBatchTask {
    /// The future that yields the batch once awaited
    pub task: ReadBatchFut,
    /// The number of rows the batch produced by `task` is expected to have
    pub num_rows: u32,
}
/// A boxed stream of [`ReadBatchTask`]s (batch futures plus their row counts)
pub type ReadBatchTaskStream = BoxStream<'static, ReadBatchTask>;
/// A boxed stream of [`ReadBatchFut`]s (batch futures with no row count attached)
pub type ReadBatchFutStream = BoxStream<'static, ReadBatchFut>;
31
/// Stream that gathers one task from each input stream per round and emits a
/// single task whose batch is the merge of all of them.
struct MergeStream {
    /// The input streams; polled in order, one task per stream per round
    streams: Vec<ReadBatchTaskStream>,
    /// The batch futures gathered so far for the round being assembled
    next_batch: FuturesOrdered<ReadBatchFut>,
    /// The row count reported by the first stream's task in the current round
    next_num_rows: u32,
    /// Index into `streams` of the next stream to poll
    index: usize,
}
38
39impl MergeStream {
40    fn emit(&mut self) -> ReadBatchTask {
41        let mut iter = std::mem::take(&mut self.next_batch);
42        let task = async move {
43            let mut batch = iter.next().await.unwrap()?;
44            while let Some(next) = iter.next().await {
45                let next = next?;
46                batch = batch.merge(&next)?;
47            }
48            Ok(batch)
49        }
50        .boxed();
51        let num_rows = self.next_num_rows;
52        self.next_num_rows = 0;
53        ReadBatchTask { task, num_rows }
54    }
55}
56
57impl Stream for MergeStream {
58    type Item = ReadBatchTask;
59
60    fn poll_next(
61        mut self: std::pin::Pin<&mut Self>,
62        cx: &mut std::task::Context<'_>,
63    ) -> std::task::Poll<Option<Self::Item>> {
64        loop {
65            let index = self.index;
66            match self.streams[index].poll_next_unpin(cx) {
67                std::task::Poll::Ready(Some(batch_task)) => {
68                    if self.index == 0 {
69                        self.next_num_rows = batch_task.num_rows;
70                    } else {
71                        debug_assert_eq!(self.next_num_rows, batch_task.num_rows);
72                    }
73                    self.next_batch.push_back(batch_task.task);
74                    self.index += 1;
75                    if self.index == self.streams.len() {
76                        self.index = 0;
77                        let next_batch = self.emit();
78                        return std::task::Poll::Ready(Some(next_batch));
79                    }
80                }
81                std::task::Poll::Ready(None) => {
82                    return std::task::Poll::Ready(None);
83                }
84                std::task::Poll::Pending => {
85                    return std::task::Poll::Pending;
86                }
87            }
88        }
89    }
90}
91
92/// Given multiple streams of batch tasks, merge them into a single stream
93///
94/// This pulls one batch from each stream and then combines the columns from
95/// all of the batches into a single batch.  The order of the batches in the
96/// streams is maintained and the merged batch columns will be in order from
97/// first to last stream.
98///
99/// This stream ends as soon as any of the input streams ends (we do not
100/// verify that the other input streams are finished as well)
101///
102/// This will panic if any of the input streams return a batch with a different
103/// number of rows than the first stream.
104pub fn merge_streams(streams: Vec<ReadBatchTaskStream>) -> ReadBatchTaskStream {
105    MergeStream {
106        streams,
107        next_batch: FuturesOrdered::new(),
108        next_num_rows: 0,
109        index: 0,
110    }
111    .boxed()
112}
113
114/// Apply a mask to the batch, where rows are "deleted" by the _rowid column null.
115///
116/// This is used partly as a performance optimization (cheaper to null than to filter)
117/// but also because there are cases where we want to load the physical rows.  For example,
118/// we may be replacing a column based on some UDF and we want to provide a value for the
119/// deleted rows to ensure the fragments are aligned.
120fn apply_deletions_as_nulls(batch: RecordBatch, mask: &BooleanArray) -> Result<RecordBatch> {
121    // Transform mask into null buffer. Null means deleted, though note that
122    // null buffers are actually validity buffers, so True means not null
123    // and thus not deleted.
124    let mask_buffer = NullBuffer::new(mask.values().clone());
125
126    match mask_buffer.null_count() {
127        // No rows are deleted
128        0 => return Ok(batch),
129        _ => {}
130    }
131
132    // For each column convert to data
133    let new_columns = batch
134        .schema()
135        .fields()
136        .iter()
137        .zip(batch.columns())
138        .map(|(field, col)| {
139            if field.name() == ROW_ID || field.name() == ROW_ADDR {
140                let col_data = col.to_data();
141                // If it already has a validity bitmap, then AND it with the mask.
142                // Otherwise, use the boolean buffer as the mask.
143                let null_buffer = NullBuffer::union(col_data.nulls(), Some(&mask_buffer));
144
145                Ok(col_data
146                    .into_builder()
147                    .null_bit_buffer(null_buffer.map(|b| b.buffer().clone()))
148                    .build()
149                    .map(make_array)?)
150            } else {
151                Ok(col.clone())
152            }
153        })
154        .collect::<Result<Vec<_>>>()?;
155
156    Ok(RecordBatch::try_new_with_options(
157        batch.schema(),
158        new_columns,
159        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
160    )?)
161}
162
/// Configuration needed to apply row ids and deletions to a batch
pub struct RowIdAndDeletesConfig {
    /// The rows that were requested.
    ///
    /// This is sliced per batch to obtain the row offsets from which the row
    /// addresses are built.
    pub params: ReadBatchParams,
    /// Whether to include the row id column in the final batch
    pub with_row_id: bool,
    /// Whether to include the row address column in the final batch
    pub with_row_addr: bool,
    /// An optional deletion vector to apply to the batch
    pub deletion_vector: Option<Arc<DeletionVector>>,
    /// An optional row id sequence to use for the row id column.
    ///
    /// When this is `None` the row ids are the same as the row addresses.
    pub row_id_sequence: Option<Arc<RowIdSequence>>,
    /// Whether to make deleted rows null instead of filtering them out
    pub make_deletions_null: bool,
    /// The total number of rows that will be loaded
    ///
    /// This is needed to convert ReadBatchParams::RangeTo into a valid range
    pub total_num_rows: u32,
}
182
183pub fn apply_row_id_and_deletes(
184    batch: RecordBatch,
185    batch_offset: u32,
186    fragment_id: u32,
187    config: &RowIdAndDeletesConfig,
188) -> Result<RecordBatch> {
189    let mut deletion_vector = config.deletion_vector.as_ref();
190    // Convert Some(NoDeletions) into None to simplify logic below
191    if let Some(deletion_vector_inner) = deletion_vector {
192        if matches!(deletion_vector_inner.as_ref(), DeletionVector::NoDeletions) {
193            deletion_vector = None;
194        }
195    }
196    let has_deletions = deletion_vector.is_some();
197    debug_assert!(
198        batch.num_columns() > 0 || config.with_row_id || config.with_row_addr || has_deletions
199    );
200
201    // If row id sequence is None, then row id IS row address.
202    let should_fetch_row_addr = config.with_row_addr
203        || (config.with_row_id && config.row_id_sequence.is_none())
204        || has_deletions;
205
206    let num_rows = batch.num_rows() as u32;
207
208    let row_addrs = if should_fetch_row_addr {
209        let row_offsets_in_batch = &config
210            .params
211            .slice(batch_offset as usize, num_rows as usize)
212            .unwrap()
213            .to_offsets()
214            .unwrap();
215        let row_addrs: UInt64Array = row_offsets_in_batch
216            .values()
217            .iter()
218            .map(|row_offset| u64::from(RowAddress::new_from_parts(fragment_id, *row_offset)))
219            .collect();
220
221        Some(Arc::new(row_addrs))
222    } else {
223        None
224    };
225
226    let row_ids = if config.with_row_id {
227        if let Some(row_id_sequence) = &config.row_id_sequence {
228            let row_ids = row_id_sequence
229                .slice(batch_offset as usize, num_rows as usize)
230                .iter()
231                .collect::<UInt64Array>();
232            Some(Arc::new(row_ids))
233        } else {
234            // If we don't have a row id sequence, can assume the row ids are
235            // the same as the row addresses.
236            row_addrs.clone()
237        }
238    } else {
239        None
240    };
241
242    let span = tracing::span!(tracing::Level::DEBUG, "apply_deletions");
243    let _enter = span.enter();
244    let deletion_mask = deletion_vector.and_then(|v| {
245        let row_addrs: &[u64] = row_addrs.as_ref().unwrap().values();
246        v.build_predicate(row_addrs.iter())
247    });
248
249    let batch = if config.with_row_id {
250        let row_id_arr = row_ids.unwrap();
251        batch.try_with_column(ROW_ID_FIELD.clone(), row_id_arr)?
252    } else {
253        batch
254    };
255
256    let batch = if config.with_row_addr {
257        let row_addr_arr = row_addrs.unwrap();
258        batch.try_with_column(ROW_ADDR_FIELD.clone(), row_addr_arr)?
259    } else {
260        batch
261    };
262
263    match (deletion_mask, config.make_deletions_null) {
264        (None, _) => Ok(batch),
265        (Some(mask), false) => Ok(arrow::compute::filter_record_batch(&batch, &mask)?),
266        (Some(mask), true) => Ok(apply_deletions_as_nulls(batch, &mask)?),
267    }
268}
269
270/// Given a stream of batch tasks this function will add a row ids column (if requested)
271/// and also apply a deletions vector to the batch.
272///
273/// This converts from BatchTaskStream to BatchFutStream because, if we are applying a
274/// deletion vector, it is impossible to know how many output rows we will have.
275pub fn wrap_with_row_id_and_delete(
276    stream: ReadBatchTaskStream,
277    fragment_id: u32,
278    config: RowIdAndDeletesConfig,
279) -> ReadBatchFutStream {
280    let config = Arc::new(config);
281    let mut offset = 0;
282    stream
283        .map(move |batch_task| {
284            let config = config.clone();
285            let this_offset = offset;
286            let num_rows = batch_task.num_rows;
287            offset += num_rows;
288            let task = batch_task.task;
289            async move {
290                let batch = task.await?;
291                apply_row_id_and_deletes(batch, this_offset, fragment_id, config.as_ref())
292            }
293            .boxed()
294        })
295        .boxed()
296}
297
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::AsArray, datatypes::UInt64Type};
    use arrow_array::{types::Int32Type, RecordBatch, UInt32Array};
    use arrow_schema::ArrowError;
    use futures::{stream::BoxStream, FutureExt, StreamExt, TryStreamExt};
    use lance_core::{
        utils::{address::RowAddress, deletion::DeletionVector},
        ROW_ID,
    };
    use lance_datagen::{BatchCount, RowCount};
    use lance_io::{stream::arrow_stream_to_lance_stream, ReadBatchParams};
    use roaring::RoaringBitmap;

    use crate::utils::stream::ReadBatchTask;

    use super::RowIdAndDeletesConfig;

    /// Wrap a stream of generated batches as a stream of already-completed
    /// batch tasks, each reporting the row count of its batch.
    fn batch_task_stream(
        datagen_stream: BoxStream<'static, std::result::Result<RecordBatch, ArrowError>>,
    ) -> super::ReadBatchTaskStream {
        arrow_stream_to_lance_stream(datagen_stream)
            .map(|batch| ReadBatchTask {
                num_rows: batch.as_ref().unwrap().num_rows() as u32,
                task: std::future::ready(batch).boxed(),
            })
            .boxed()
    }

    /// Merging an `x` stream with a `y` stream gives the same batches as
    /// generating `x` and `y` together.
    #[tokio::test]
    async fn test_basic_zip() {
        let left = batch_task_stream(
            lance_datagen::gen()
                .col("x", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );
        let right = batch_task_stream(
            lance_datagen::gen()
                .col("y", lance_datagen::array::step::<Int32Type>())
                .into_reader_stream(RowCount::from(100), BatchCount::from(10))
                .0,
        );

        let merged = super::merge_streams(vec![left, right])
            .map(|batch_task| batch_task.task)
            .buffered(1)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let expected = lance_datagen::gen()
            .col("x", lance_datagen::array::step::<Int32Type>())
            .col("y", lance_datagen::array::step::<Int32Type>())
            .into_reader_rows(RowCount::from(100), BatchCount::from(10))
            .collect::<Result<Vec<_>, ArrowError>>()
            .unwrap();
        assert_eq!(merged, expected);
    }

    /// Read 100 rows (10 batches of 10) with the given `params` and check that
    /// the row id column holds, in order, the addresses built from the
    /// `expected` row offsets and the fragment id.
    ///
    /// Runs with and without a data column and for two fragment ids.
    async fn check_row_id(params: ReadBatchParams, expected: impl IntoIterator<Item = u32>) {
        let expected = Vec::from_iter(expected);

        for has_columns in [false, true] {
            for fragment_id in [0, 10] {
                // 100 rows across 10 batches of 10 rows
                let mut datagen = lance_datagen::gen();
                if has_columns {
                    datagen = datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                }
                let data = batch_task_stream(
                    datagen
                        .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                        .0,
                );

                let config = RowIdAndDeletesConfig {
                    params: params.clone(),
                    with_row_id: true,
                    with_row_addr: false,
                    deletion_vector: None,
                    row_id_sequence: None,
                    make_deletions_null: false,
                    total_num_rows: 100,
                };
                let stream = super::wrap_with_row_id_and_delete(data, fragment_id, config);
                let batches = stream.buffered(1).try_collect::<Vec<_>>().await.unwrap();

                let mut offset = 0;
                for batch in batches {
                    // Slice by the actual batch length rather than assuming 10
                    let batch_len = batch.num_rows();
                    let actual_row_ids =
                        batch[ROW_ID].as_primitive::<UInt64Type>().values().to_vec();
                    let expected_row_ids = expected[offset..offset + batch_len]
                        .iter()
                        .map(|row_offset| {
                            RowAddress::new_from_parts(fragment_id, *row_offset).into()
                        })
                        .collect::<Vec<u64>>();
                    assert_eq!(actual_row_ids, expected_row_ids);
                    offset += batch_len;
                }
                // Every expected row must have been produced; without this a
                // stream that ended early would still pass.
                assert_eq!(offset, expected.len());
            }
        }
    }

    /// Row ids are derived correctly for every kind of `ReadBatchParams`.
    #[tokio::test]
    async fn test_row_id() {
        let some_indices = (0..100).rev().collect::<Vec<u32>>();
        let some_indices_arr = UInt32Array::from(some_indices.clone());
        check_row_id(ReadBatchParams::RangeFull, 0..100).await;
        check_row_id(ReadBatchParams::Indices(some_indices_arr), some_indices).await;
        check_row_id(ReadBatchParams::Range(1000..1100), 1000..1100).await;
        check_row_id(
            ReadBatchParams::RangeFrom(std::ops::RangeFrom { start: 1000 }),
            1000..1100,
        )
        .await;
        check_row_id(
            ReadBatchParams::RangeTo(std::ops::RangeTo { end: 1000 }),
            0..100,
        )
        .await;
    }

    /// Deletions are applied correctly, whether deleted rows are filtered out
    /// or nulled, for every kind of deletion vector and with or without data
    /// columns and a row id column.
    #[tokio::test]
    async fn test_deletes() {
        let no_deletes: Option<Arc<DeletionVector>> = None;
        let no_deletes_2 = Some(Arc::new(DeletionVector::NoDeletions));
        let delete_some_bitmap = Some(Arc::new(DeletionVector::Bitmap(RoaringBitmap::from_iter(
            0..35,
        ))));
        let delete_some_set = Some(Arc::new(DeletionVector::Set((0..35).collect())));

        for deletion_vector in [
            no_deletes,
            no_deletes_2,
            delete_some_bitmap,
            delete_some_set,
        ] {
            for has_columns in [false, true] {
                for with_row_id in [false, true] {
                    for make_deletions_null in [false, true] {
                        for frag_id in [0, 1] {
                            let has_deletions = if let Some(dv) = &deletion_vector {
                                !matches!(dv.as_ref(), DeletionVector::NoDeletions)
                            } else {
                                false
                            };
                            if !has_columns && !has_deletions && !with_row_id {
                                // This is an invalid case and should be prevented upstream,
                                // no meaningful work is being done!
                                continue;
                            }
                            if make_deletions_null && !with_row_id {
                                // This is an invalid case and should be prevented upstream
                                // we cannot make the row_id column null if it isn't present
                                continue;
                            }

                            let mut datagen = lance_datagen::gen();
                            if has_columns {
                                datagen =
                                    datagen.col("x", lance_datagen::array::rand::<Int32Type>());
                            }
                            // 100 rows across 10 batches of 10 rows
                            let data = batch_task_stream(
                                datagen
                                    .into_reader_stream(RowCount::from(10), BatchCount::from(10))
                                    .0,
                            );

                            let config = RowIdAndDeletesConfig {
                                params: ReadBatchParams::RangeFull,
                                with_row_id,
                                with_row_addr: false,
                                deletion_vector: deletion_vector.clone(),
                                row_id_sequence: None,
                                make_deletions_null,
                                total_num_rows: 100,
                            };
                            let stream = super::wrap_with_row_id_and_delete(data, frag_id, config);
                            // Drop batches that ended up empty (every row deleted)
                            let batches = stream
                                .buffered(1)
                                .filter_map(|batch| {
                                    std::future::ready(
                                        batch
                                            .map(|batch| {
                                                if batch.num_rows() == 0 {
                                                    None
                                                } else {
                                                    Some(batch)
                                                }
                                            })
                                            .transpose(),
                                    )
                                })
                                .try_collect::<Vec<_>>()
                                .await
                                .unwrap();

                            // A deleted row is either missing from the output or
                            // present with a null row id
                            let total_num_rows =
                                batches.iter().map(|b| b.num_rows()).sum::<usize>();
                            let total_num_nulls = if make_deletions_null {
                                batches
                                    .iter()
                                    .map(|b| b[ROW_ID].null_count())
                                    .sum::<usize>()
                            } else {
                                0
                            };
                            let total_actually_deleted = total_num_nulls + (100 - total_num_rows);

                            let expected_deletions = match &deletion_vector {
                                None => 0,
                                Some(deletion_vector) => match deletion_vector.as_ref() {
                                    DeletionVector::NoDeletions => 0,
                                    DeletionVector::Bitmap(b) => b.len() as usize,
                                    DeletionVector::Set(s) => s.len(),
                                },
                            };
                            assert_eq!(total_actually_deleted, expected_deletions);
                            if expected_deletions > 0 && with_row_id {
                                if make_deletions_null {
                                    // If we make deletions null we get 3 batches of all-null and then
                                    // a batch of half-null
                                    assert_eq!(
                                        batches[3][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 30))
                                    );
                                    assert_eq!(batches[3][ROW_ID].null_count(), 5);
                                } else {
                                    // If we materialize deletions the first row will be 35
                                    assert_eq!(
                                        batches[0][ROW_ID].as_primitive::<UInt64Type>().value(0),
                                        u64::from(RowAddress::new_from_parts(frag_id, 35))
                                    );
                                }
                            }
                            if !with_row_id {
                                assert!(batches[0].column_by_name(ROW_ID).is_none());
                            }
                        }
                    }
                }
            }
        }
    }
}