lance_datafusion/
datagen.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_array::RecordBatchReader;
7use datafusion::{
8    execution::SendableRecordBatchStream,
9    physical_plan::{stream::RecordBatchStreamAdapter, ExecutionPlan},
10};
11use datafusion_common::DataFusionError;
12use futures::TryStreamExt;
13use lance_core::Error;
14use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RoundingBehavior, RowCount};
15
16use crate::exec::OneShotExec;
17
18pub trait DatafusionDatagenExt {
19    fn into_df_stream(
20        self,
21        batch_size: RowCount,
22        num_batches: BatchCount,
23    ) -> SendableRecordBatchStream;
24
25    fn into_df_stream_bytes(
26        self,
27        batch_size: ByteCount,
28        num_batches: BatchCount,
29        rounding_behavior: RoundingBehavior,
30    ) -> Result<SendableRecordBatchStream, Error>;
31
32    fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc<dyn ExecutionPlan>;
33}
34
35impl DatafusionDatagenExt for BatchGeneratorBuilder {
36    fn into_df_stream(
37        self,
38        batch_size: RowCount,
39        num_batches: BatchCount,
40    ) -> SendableRecordBatchStream {
41        let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
42        let stream = stream.map_err(DataFusionError::from);
43        Box::pin(RecordBatchStreamAdapter::new(schema, stream))
44    }
45
46    fn into_df_stream_bytes(
47        self,
48        batch_size: ByteCount,
49        num_batches: BatchCount,
50        rounding_behavior: RoundingBehavior,
51    ) -> Result<SendableRecordBatchStream, Error> {
52        let stream = self.into_reader_bytes(batch_size, num_batches, rounding_behavior)?;
53        let schema = stream.schema();
54        let stream = futures::stream::iter(stream).map_err(DataFusionError::from);
55        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
56    }
57
58    fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc<dyn ExecutionPlan> {
59        let stream = self.into_df_stream(batch_size, num_batches);
60        Arc::new(OneShotExec::new(stream))
61    }
62}