lance_datafusion/
datagen.rs1use std::sync::Arc;
5
6use arrow_array::RecordBatchReader;
7use datafusion::{
8 execution::SendableRecordBatchStream,
9 physical_plan::{stream::RecordBatchStreamAdapter, ExecutionPlan},
10};
11use datafusion_common::DataFusionError;
12use futures::TryStreamExt;
13use lance_core::Error;
14use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RoundingBehavior, RowCount};
15
16use crate::exec::OneShotExec;
17
18pub trait DatafusionDatagenExt {
19 fn into_df_stream(
20 self,
21 batch_size: RowCount,
22 num_batches: BatchCount,
23 ) -> SendableRecordBatchStream;
24
25 fn into_df_stream_bytes(
26 self,
27 batch_size: ByteCount,
28 num_batches: BatchCount,
29 rounding_behavior: RoundingBehavior,
30 ) -> Result<SendableRecordBatchStream, Error>;
31
32 fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc<dyn ExecutionPlan>;
33}
34
35impl DatafusionDatagenExt for BatchGeneratorBuilder {
36 fn into_df_stream(
37 self,
38 batch_size: RowCount,
39 num_batches: BatchCount,
40 ) -> SendableRecordBatchStream {
41 let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
42 let stream = stream.map_err(DataFusionError::from);
43 Box::pin(RecordBatchStreamAdapter::new(schema, stream))
44 }
45
46 fn into_df_stream_bytes(
47 self,
48 batch_size: ByteCount,
49 num_batches: BatchCount,
50 rounding_behavior: RoundingBehavior,
51 ) -> Result<SendableRecordBatchStream, Error> {
52 let stream = self.into_reader_bytes(batch_size, num_batches, rounding_behavior)?;
53 let schema = stream.schema();
54 let stream = futures::stream::iter(stream).map_err(DataFusionError::from);
55 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
56 }
57
58 fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc<dyn ExecutionPlan> {
59 let stream = self.into_df_stream(batch_size, num_batches);
60 Arc::new(OneShotExec::new(stream))
61 }
62}