lance_datafusion/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::borrow::Cow;
5
6use arrow::ffi_stream::ArrowArrayStreamReader;
7use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
8use arrow_schema::{ArrowError, SchemaRef};
9use async_trait::async_trait;
10use background_iterator::BackgroundIterator;
11use datafusion::{
12    execution::RecordBatchStream,
13    physical_plan::{
14        metrics::{
15            Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, MetricValue, MetricsSet, Time,
16        },
17        stream::RecordBatchStreamAdapter,
18        SendableRecordBatchStream,
19    },
20};
21use datafusion_common::DataFusionError;
22use futures::{stream, StreamExt, TryStreamExt};
23use lance_core::datatypes::Schema;
24use lance_core::Result;
25use tokio::task::spawn;
26
27pub mod background_iterator;
28
29/// A trait for [BatchRecord] iterators, readers and streams
30/// that can be converted to a concrete stream type [SendableRecordBatchStream].
31///
32/// This also cam read the schema from the first batch
33/// and then update the schema to reflect the dictionary columns.
34#[async_trait]
35pub trait StreamingWriteSource: Send {
36    /// Infer the Lance schema from the first batch stream.
37    ///
38    /// This will peek the first batch to get the dictionaries for dictionary columns.
39    ///
40    /// NOTE: this does not validate the schema. For example, for appends the schema
41    /// should be checked to make sure it matches the existing dataset schema before
42    /// writing.
43    async fn into_stream_and_schema(self) -> Result<(SendableRecordBatchStream, Schema)>
44    where
45        Self: Sized,
46    {
47        let mut stream = self.into_stream();
48        let (stream, arrow_schema, schema) = spawn(async move {
49            let arrow_schema = stream.schema();
50            let mut schema: Schema = Schema::try_from(arrow_schema.as_ref())?;
51            let first_batch = stream.try_next().await?;
52            if let Some(batch) = &first_batch {
53                schema.set_dictionary(batch)?;
54            }
55            let stream = stream::iter(first_batch.map(Ok)).chain(stream);
56            Result::Ok((stream, arrow_schema, schema))
57        })
58        .await
59        .unwrap()?;
60        schema.validate()?;
61        let adapter = RecordBatchStreamAdapter::new(arrow_schema, stream);
62        Ok((Box::pin(adapter), schema))
63    }
64
65    /// Returns the arrow schema.
66    fn arrow_schema(&self) -> SchemaRef;
67
68    /// Convert to a stream.
69    ///
70    /// The conversion will be conducted in a background thread.
71    fn into_stream(self) -> SendableRecordBatchStream;
72}
73
74impl StreamingWriteSource for ArrowArrayStreamReader {
75    #[inline]
76    fn arrow_schema(&self) -> SchemaRef {
77        RecordBatchReader::schema(self)
78    }
79
80    #[inline]
81    fn into_stream(self) -> SendableRecordBatchStream {
82        reader_to_stream(Box::new(self))
83    }
84}
85
86impl<I> StreamingWriteSource for RecordBatchIterator<I>
87where
88    Self: Send,
89    I: IntoIterator<Item = ::core::result::Result<RecordBatch, ArrowError>> + Send + 'static,
90{
91    #[inline]
92    fn arrow_schema(&self) -> SchemaRef {
93        RecordBatchReader::schema(self)
94    }
95
96    #[inline]
97    fn into_stream(self) -> SendableRecordBatchStream {
98        reader_to_stream(Box::new(self))
99    }
100}
101
102impl<T> StreamingWriteSource for Box<T>
103where
104    T: StreamingWriteSource,
105{
106    #[inline]
107    fn arrow_schema(&self) -> SchemaRef {
108        T::arrow_schema(&**self)
109    }
110
111    #[inline]
112    fn into_stream(self) -> SendableRecordBatchStream {
113        T::into_stream(*self)
114    }
115}
116
117impl StreamingWriteSource for Box<dyn RecordBatchReader + Send> {
118    #[inline]
119    fn arrow_schema(&self) -> SchemaRef {
120        RecordBatchReader::schema(self)
121    }
122
123    #[inline]
124    fn into_stream(self) -> SendableRecordBatchStream {
125        reader_to_stream(self)
126    }
127}
128
129impl StreamingWriteSource for SendableRecordBatchStream {
130    #[inline]
131    fn arrow_schema(&self) -> SchemaRef {
132        RecordBatchStream::schema(&**self)
133    }
134
135    #[inline]
136    fn into_stream(self) -> SendableRecordBatchStream {
137        self
138    }
139}
140
141/// Convert reader to a stream.
142///
143/// The reader will be called in a background thread.
144pub fn reader_to_stream(batches: Box<dyn RecordBatchReader + Send>) -> SendableRecordBatchStream {
145    let arrow_schema = batches.arrow_schema();
146    let stream = RecordBatchStreamAdapter::new(
147        arrow_schema,
148        BackgroundIterator::new(batches)
149            .fuse()
150            .map_err(DataFusionError::from),
151    );
152    Box::pin(stream)
153}
154
155pub trait MetricsExt {
156    fn find_count(&self, name: &str) -> Option<Count>;
157    fn iter_counts(&self) -> impl Iterator<Item = (impl AsRef<str>, &Count)>;
158    fn iter_gauges(&self) -> impl Iterator<Item = (impl AsRef<str>, &Gauge)>;
159}
160
161impl MetricsExt for MetricsSet {
162    fn find_count(&self, metric_name: &str) -> Option<Count> {
163        self.iter().find_map(|m| match m.value() {
164            MetricValue::Count { name, count } => {
165                if name == metric_name {
166                    Some(count.clone())
167                } else {
168                    None
169                }
170            }
171            _ => None,
172        })
173    }
174
175    fn iter_counts(&self) -> impl Iterator<Item = (impl AsRef<str>, &Count)> {
176        self.iter().filter_map(|m| match m.value() {
177            MetricValue::Count { name, count } => Some((name, count)),
178            _ => None,
179        })
180    }
181
182    fn iter_gauges(&self) -> impl Iterator<Item = (impl AsRef<str>, &Gauge)> {
183        self.iter().filter_map(|m| match m.value() {
184            MetricValue::Gauge { name, gauge } => Some((name, gauge)),
185            _ => None,
186        })
187    }
188}
189
190pub trait ExecutionPlanMetricsSetExt {
191    fn new_count(&self, name: &'static str, partition: usize) -> Count;
192    fn new_time(&self, name: &'static str, partition: usize) -> Time;
193    fn new_gauge(&self, name: &'static str, partition: usize) -> Gauge;
194}
195
196impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet {
197    fn new_count(&self, name: &'static str, partition: usize) -> Count {
198        let count = Count::new();
199        MetricBuilder::new(self)
200            .with_partition(partition)
201            .build(MetricValue::Count {
202                name: Cow::Borrowed(name),
203                count: count.clone(),
204            });
205        count
206    }
207
208    fn new_time(&self, name: &'static str, partition: usize) -> Time {
209        let time = Time::new();
210        MetricBuilder::new(self)
211            .with_partition(partition)
212            .build(MetricValue::Time {
213                name: Cow::Borrowed(name),
214                time: time.clone(),
215            });
216        time
217    }
218
219    fn new_gauge(&self, name: &'static str, partition: usize) -> Gauge {
220        let gauge = Gauge::new();
221        MetricBuilder::new(self)
222            .with_partition(partition)
223            .build(MetricValue::Gauge {
224                name: Cow::Borrowed(name),
225                gauge: gauge.clone(),
226            });
227        gauge
228    }
229}
230
// Common metrics
//
// Each constant below is the name string of one metric.

/// Metric name `"iops"`.
pub const IOPS_METRIC: &str = "iops";
/// Metric name `"requests"`.
pub const REQUESTS_METRIC: &str = "requests";
/// Metric name `"bytes_read"`.
pub const BYTES_READ_METRIC: &str = "bytes_read";
/// Metric name `"indices_loaded"`.
pub const INDICES_LOADED_METRIC: &str = "indices_loaded";
/// Metric name `"parts_loaded"`.
pub const PARTS_LOADED_METRIC: &str = "parts_loaded";
/// Metric name `"partitions_ranked"`.
pub const PARTITIONS_RANKED_METRIC: &str = "partitions_ranked";
/// Metric name `"index_comparisons"`.
pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons";
/// Metric name `"fragments_scanned"`.
pub const FRAGMENTS_SCANNED_METRIC: &str = "fragments_scanned";
/// Metric name `"ranges_scanned"`.
pub const RANGES_SCANNED_METRIC: &str = "ranges_scanned";
/// Metric name `"rows_scanned"`.
pub const ROWS_SCANNED_METRIC: &str = "rows_scanned";
/// Metric name `"task_wait_time"`.
pub const TASK_WAIT_TIME_METRIC: &str = "task_wait_time";
/// Metric name `"deltas_searched"`.
pub const DELTAS_SEARCHED_METRIC: &str = "deltas_searched";
/// Metric name `"partitions_searched"`.
pub const PARTITIONS_SEARCHED_METRIC: &str = "partitions_searched";
/// Metric name `"search_time"`; the string omits the `scalar_index` prefix
/// that the constant's name carries.
pub const SCALAR_INDEX_SEARCH_TIME_METRIC: &str = "search_time";
/// Metric name `"ser_time"`; the string omits the `scalar_index` prefix
/// that the constant's name carries.
pub const SCALAR_INDEX_SER_TIME_METRIC: &str = "ser_time";