lance_datafusion/
utils.rs1use std::borrow::Cow;
5
6use arrow::ffi_stream::ArrowArrayStreamReader;
7use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
8use arrow_schema::{ArrowError, SchemaRef};
9use async_trait::async_trait;
10use background_iterator::BackgroundIterator;
11use datafusion::{
12 execution::RecordBatchStream,
13 physical_plan::{
14 metrics::{
15 Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, MetricValue, MetricsSet, Time,
16 },
17 stream::RecordBatchStreamAdapter,
18 SendableRecordBatchStream,
19 },
20};
21use datafusion_common::DataFusionError;
22use futures::{stream, StreamExt, TryStreamExt};
23use lance_core::datatypes::Schema;
24use lance_core::Result;
25use tokio::task::spawn;
26
27pub mod background_iterator;
28
29#[async_trait]
35pub trait StreamingWriteSource: Send {
36 async fn into_stream_and_schema(self) -> Result<(SendableRecordBatchStream, Schema)>
44 where
45 Self: Sized,
46 {
47 let mut stream = self.into_stream();
48 let (stream, arrow_schema, schema) = spawn(async move {
49 let arrow_schema = stream.schema();
50 let mut schema: Schema = Schema::try_from(arrow_schema.as_ref())?;
51 let first_batch = stream.try_next().await?;
52 if let Some(batch) = &first_batch {
53 schema.set_dictionary(batch)?;
54 }
55 let stream = stream::iter(first_batch.map(Ok)).chain(stream);
56 Result::Ok((stream, arrow_schema, schema))
57 })
58 .await
59 .unwrap()?;
60 schema.validate()?;
61 let adapter = RecordBatchStreamAdapter::new(arrow_schema, stream);
62 Ok((Box::pin(adapter), schema))
63 }
64
65 fn arrow_schema(&self) -> SchemaRef;
67
68 fn into_stream(self) -> SendableRecordBatchStream;
72}
73
74impl StreamingWriteSource for ArrowArrayStreamReader {
75 #[inline]
76 fn arrow_schema(&self) -> SchemaRef {
77 RecordBatchReader::schema(self)
78 }
79
80 #[inline]
81 fn into_stream(self) -> SendableRecordBatchStream {
82 reader_to_stream(Box::new(self))
83 }
84}
85
86impl<I> StreamingWriteSource for RecordBatchIterator<I>
87where
88 Self: Send,
89 I: IntoIterator<Item = ::core::result::Result<RecordBatch, ArrowError>> + Send + 'static,
90{
91 #[inline]
92 fn arrow_schema(&self) -> SchemaRef {
93 RecordBatchReader::schema(self)
94 }
95
96 #[inline]
97 fn into_stream(self) -> SendableRecordBatchStream {
98 reader_to_stream(Box::new(self))
99 }
100}
101
102impl<T> StreamingWriteSource for Box<T>
103where
104 T: StreamingWriteSource,
105{
106 #[inline]
107 fn arrow_schema(&self) -> SchemaRef {
108 T::arrow_schema(&**self)
109 }
110
111 #[inline]
112 fn into_stream(self) -> SendableRecordBatchStream {
113 T::into_stream(*self)
114 }
115}
116
117impl StreamingWriteSource for Box<dyn RecordBatchReader + Send> {
118 #[inline]
119 fn arrow_schema(&self) -> SchemaRef {
120 RecordBatchReader::schema(self)
121 }
122
123 #[inline]
124 fn into_stream(self) -> SendableRecordBatchStream {
125 reader_to_stream(self)
126 }
127}
128
129impl StreamingWriteSource for SendableRecordBatchStream {
130 #[inline]
131 fn arrow_schema(&self) -> SchemaRef {
132 RecordBatchStream::schema(&**self)
133 }
134
135 #[inline]
136 fn into_stream(self) -> SendableRecordBatchStream {
137 self
138 }
139}
140
141pub fn reader_to_stream(batches: Box<dyn RecordBatchReader + Send>) -> SendableRecordBatchStream {
145 let arrow_schema = batches.arrow_schema();
146 let stream = RecordBatchStreamAdapter::new(
147 arrow_schema,
148 BackgroundIterator::new(batches)
149 .fuse()
150 .map_err(DataFusionError::from),
151 );
152 Box::pin(stream)
153}
154
155pub trait MetricsExt {
156 fn find_count(&self, name: &str) -> Option<Count>;
157 fn iter_counts(&self) -> impl Iterator<Item = (impl AsRef<str>, &Count)>;
158 fn iter_gauges(&self) -> impl Iterator<Item = (impl AsRef<str>, &Gauge)>;
159}
160
161impl MetricsExt for MetricsSet {
162 fn find_count(&self, metric_name: &str) -> Option<Count> {
163 self.iter().find_map(|m| match m.value() {
164 MetricValue::Count { name, count } => {
165 if name == metric_name {
166 Some(count.clone())
167 } else {
168 None
169 }
170 }
171 _ => None,
172 })
173 }
174
175 fn iter_counts(&self) -> impl Iterator<Item = (impl AsRef<str>, &Count)> {
176 self.iter().filter_map(|m| match m.value() {
177 MetricValue::Count { name, count } => Some((name, count)),
178 _ => None,
179 })
180 }
181
182 fn iter_gauges(&self) -> impl Iterator<Item = (impl AsRef<str>, &Gauge)> {
183 self.iter().filter_map(|m| match m.value() {
184 MetricValue::Gauge { name, gauge } => Some((name, gauge)),
185 _ => None,
186 })
187 }
188}
189
190pub trait ExecutionPlanMetricsSetExt {
191 fn new_count(&self, name: &'static str, partition: usize) -> Count;
192 fn new_time(&self, name: &'static str, partition: usize) -> Time;
193 fn new_gauge(&self, name: &'static str, partition: usize) -> Gauge;
194}
195
196impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet {
197 fn new_count(&self, name: &'static str, partition: usize) -> Count {
198 let count = Count::new();
199 MetricBuilder::new(self)
200 .with_partition(partition)
201 .build(MetricValue::Count {
202 name: Cow::Borrowed(name),
203 count: count.clone(),
204 });
205 count
206 }
207
208 fn new_time(&self, name: &'static str, partition: usize) -> Time {
209 let time = Time::new();
210 MetricBuilder::new(self)
211 .with_partition(partition)
212 .build(MetricValue::Time {
213 name: Cow::Borrowed(name),
214 time: time.clone(),
215 });
216 time
217 }
218
219 fn new_gauge(&self, name: &'static str, partition: usize) -> Gauge {
220 let gauge = Gauge::new();
221 MetricBuilder::new(self)
222 .with_partition(partition)
223 .build(MetricValue::Gauge {
224 name: Cow::Borrowed(name),
225 gauge: gauge.clone(),
226 });
227 gauge
228 }
229}
230
// Metric-name constants. Only the string values are visible here. The
// descriptions below are inferred from the names and should be confirmed
// against the code that records each metric.

/// Metric name `"iops"` (presumably the number of I/O operations issued — confirm).
pub const IOPS_METRIC: &str = "iops";
/// Metric name `"requests"` (presumably the number of requests made — confirm).
pub const REQUESTS_METRIC: &str = "requests";
/// Metric name `"bytes_read"` (presumably total bytes read — confirm).
pub const BYTES_READ_METRIC: &str = "bytes_read";
/// Metric name `"indices_loaded"` (presumably the number of indices loaded — confirm).
pub const INDICES_LOADED_METRIC: &str = "indices_loaded";
/// Metric name `"parts_loaded"` (presumably the number of index parts loaded — confirm).
pub const PARTS_LOADED_METRIC: &str = "parts_loaded";
/// Metric name `"partitions_ranked"` (presumably index partitions ranked — confirm).
pub const PARTITIONS_RANKED_METRIC: &str = "partitions_ranked";
/// Metric name `"index_comparisons"` (presumably comparisons made during index search — confirm).
pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons";
/// Metric name `"fragments_scanned"` (presumably the number of fragments scanned — confirm).
pub const FRAGMENTS_SCANNED_METRIC: &str = "fragments_scanned";
/// Metric name `"ranges_scanned"` (presumably the number of row ranges scanned — confirm).
pub const RANGES_SCANNED_METRIC: &str = "ranges_scanned";
/// Metric name `"rows_scanned"` (presumably the number of rows scanned — confirm).
pub const ROWS_SCANNED_METRIC: &str = "rows_scanned";
/// Metric name `"task_wait_time"` (presumably time spent waiting on tasks — confirm).
pub const TASK_WAIT_TIME_METRIC: &str = "task_wait_time";
/// Metric name `"deltas_searched"` (presumably index deltas searched — confirm).
pub const DELTAS_SEARCHED_METRIC: &str = "deltas_searched";
/// Metric name `"partitions_searched"` (presumably index partitions searched — confirm).
pub const PARTITIONS_SEARCHED_METRIC: &str = "partitions_searched";
/// Metric name `"search_time"` (presumably scalar-index search time — confirm).
/// NOTE(review): the value has no `scalar_index` prefix, unlike the constant name.
/// Check that no other metric in the same set also uses `search_time`.
pub const SCALAR_INDEX_SEARCH_TIME_METRIC: &str = "search_time";
/// Metric name `"ser_time"` (presumably scalar-index serialization time — confirm).
/// NOTE(review): the value has no `scalar_index` prefix, unlike the constant name.
pub const SCALAR_INDEX_SER_TIME_METRIC: &str = "ser_time";