lance_datafusion/
utils.rs1use std::borrow::Cow;
5
6use arrow::ffi_stream::ArrowArrayStreamReader;
7use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
8use arrow_schema::{ArrowError, SchemaRef};
9use async_trait::async_trait;
10use background_iterator::BackgroundIterator;
11use datafusion::{
12 execution::RecordBatchStream,
13 physical_plan::{
14 metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet},
15 stream::RecordBatchStreamAdapter,
16 SendableRecordBatchStream,
17 },
18};
19use datafusion_common::DataFusionError;
20use futures::{stream, StreamExt, TryStreamExt};
21use lance_core::datatypes::Schema;
22use lance_core::Result;
23use tokio::task::spawn;
24
25pub mod background_iterator;
26
27#[async_trait]
33pub trait StreamingWriteSource: Send {
34 async fn into_stream_and_schema(self) -> Result<(SendableRecordBatchStream, Schema)>
42 where
43 Self: Sized,
44 {
45 let mut stream = self.into_stream();
46 let (stream, arrow_schema, schema) = spawn(async move {
47 let arrow_schema = stream.schema();
48 let mut schema: Schema = Schema::try_from(arrow_schema.as_ref())?;
49 let first_batch = stream.try_next().await?;
50 if let Some(batch) = &first_batch {
51 schema.set_dictionary(batch)?;
52 }
53 let stream = stream::iter(first_batch.map(Ok)).chain(stream);
54 Result::Ok((stream, arrow_schema, schema))
55 })
56 .await
57 .unwrap()?;
58 schema.validate()?;
59 let adapter = RecordBatchStreamAdapter::new(arrow_schema, stream);
60 Ok((Box::pin(adapter), schema))
61 }
62
63 fn arrow_schema(&self) -> SchemaRef;
65
66 fn into_stream(self) -> SendableRecordBatchStream;
70}
71
72impl StreamingWriteSource for ArrowArrayStreamReader {
73 #[inline]
74 fn arrow_schema(&self) -> SchemaRef {
75 RecordBatchReader::schema(self)
76 }
77
78 #[inline]
79 fn into_stream(self) -> SendableRecordBatchStream {
80 reader_to_stream(Box::new(self))
81 }
82}
83
84impl<I> StreamingWriteSource for RecordBatchIterator<I>
85where
86 Self: Send,
87 I: IntoIterator<Item = ::core::result::Result<RecordBatch, ArrowError>> + Send + 'static,
88{
89 #[inline]
90 fn arrow_schema(&self) -> SchemaRef {
91 RecordBatchReader::schema(self)
92 }
93
94 #[inline]
95 fn into_stream(self) -> SendableRecordBatchStream {
96 reader_to_stream(Box::new(self))
97 }
98}
99
100impl<T> StreamingWriteSource for Box<T>
101where
102 T: StreamingWriteSource,
103{
104 #[inline]
105 fn arrow_schema(&self) -> SchemaRef {
106 T::arrow_schema(&**self)
107 }
108
109 #[inline]
110 fn into_stream(self) -> SendableRecordBatchStream {
111 T::into_stream(*self)
112 }
113}
114
115impl StreamingWriteSource for Box<dyn RecordBatchReader + Send> {
116 #[inline]
117 fn arrow_schema(&self) -> SchemaRef {
118 RecordBatchReader::schema(self)
119 }
120
121 #[inline]
122 fn into_stream(self) -> SendableRecordBatchStream {
123 reader_to_stream(self)
124 }
125}
126
127impl StreamingWriteSource for SendableRecordBatchStream {
128 #[inline]
129 fn arrow_schema(&self) -> SchemaRef {
130 RecordBatchStream::schema(&**self)
131 }
132
133 #[inline]
134 fn into_stream(self) -> SendableRecordBatchStream {
135 self
136 }
137}
138
139pub fn reader_to_stream(batches: Box<dyn RecordBatchReader + Send>) -> SendableRecordBatchStream {
143 let arrow_schema = batches.arrow_schema();
144 let stream = RecordBatchStreamAdapter::new(
145 arrow_schema,
146 BackgroundIterator::new(batches)
147 .fuse()
148 .map_err(DataFusionError::from),
149 );
150 Box::pin(stream)
151}
152
153pub trait MetricsExt {
154 fn find_count(&self, name: &str) -> Option<Count>;
155}
156
157impl MetricsExt for MetricsSet {
158 fn find_count(&self, metric_name: &str) -> Option<Count> {
159 self.iter().find_map(|m| match m.value() {
160 MetricValue::Count { name, count } => {
161 if name == metric_name {
162 Some(count.clone())
163 } else {
164 None
165 }
166 }
167 _ => None,
168 })
169 }
170}
171
172pub trait ExecutionPlanMetricsSetExt {
173 fn new_count(&self, name: &'static str, partition: usize) -> Count;
174}
175
176impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet {
177 fn new_count(&self, name: &'static str, partition: usize) -> Count {
178 let count = Count::new();
179 MetricBuilder::new(self)
180 .with_partition(partition)
181 .build(MetricValue::Count {
182 name: Cow::Borrowed(name),
183 count: count.clone(),
184 });
185 count
186 }
187}
188
/// Metric name `"iops"`.
pub const IOPS_METRIC: &str = "iops";
/// Metric name `"requests"`.
pub const REQUESTS_METRIC: &str = "requests";
/// Metric name `"bytes_read"`.
pub const BYTES_READ_METRIC: &str = "bytes_read";
/// Metric name `"indices_loaded"`.
pub const INDICES_LOADED_METRIC: &str = "indices_loaded";
/// Metric name `"parts_loaded"`.
pub const PARTS_LOADED_METRIC: &str = "parts_loaded";
/// Metric name `"index_comparisons"`.
pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons";