lance_datafusion/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::borrow::Cow;
5
6use arrow::ffi_stream::ArrowArrayStreamReader;
7use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
8use arrow_schema::{ArrowError, SchemaRef};
9use async_trait::async_trait;
10use background_iterator::BackgroundIterator;
11use datafusion::{
12    execution::RecordBatchStream,
13    physical_plan::{
14        metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet},
15        stream::RecordBatchStreamAdapter,
16        SendableRecordBatchStream,
17    },
18};
19use datafusion_common::DataFusionError;
20use futures::{stream, StreamExt, TryStreamExt};
21use lance_core::datatypes::Schema;
22use lance_core::Result;
23use tokio::task::spawn;
24
25pub mod background_iterator;
26
27/// A trait for [BatchRecord] iterators, readers and streams
28/// that can be converted to a concrete stream type [SendableRecordBatchStream].
29///
30/// This also cam read the schema from the first batch
31/// and then update the schema to reflect the dictionary columns.
32#[async_trait]
33pub trait StreamingWriteSource: Send {
34    /// Infer the Lance schema from the first batch stream.
35    ///
36    /// This will peek the first batch to get the dictionaries for dictionary columns.
37    ///
38    /// NOTE: this does not validate the schema. For example, for appends the schema
39    /// should be checked to make sure it matches the existing dataset schema before
40    /// writing.
41    async fn into_stream_and_schema(self) -> Result<(SendableRecordBatchStream, Schema)>
42    where
43        Self: Sized,
44    {
45        let mut stream = self.into_stream();
46        let (stream, arrow_schema, schema) = spawn(async move {
47            let arrow_schema = stream.schema();
48            let mut schema: Schema = Schema::try_from(arrow_schema.as_ref())?;
49            let first_batch = stream.try_next().await?;
50            if let Some(batch) = &first_batch {
51                schema.set_dictionary(batch)?;
52            }
53            let stream = stream::iter(first_batch.map(Ok)).chain(stream);
54            Result::Ok((stream, arrow_schema, schema))
55        })
56        .await
57        .unwrap()?;
58        schema.validate()?;
59        let adapter = RecordBatchStreamAdapter::new(arrow_schema, stream);
60        Ok((Box::pin(adapter), schema))
61    }
62
63    /// Returns the arrow schema.
64    fn arrow_schema(&self) -> SchemaRef;
65
66    /// Convert to a stream.
67    ///
68    /// The conversion will be conducted in a background thread.
69    fn into_stream(self) -> SendableRecordBatchStream;
70}
71
72impl StreamingWriteSource for ArrowArrayStreamReader {
73    #[inline]
74    fn arrow_schema(&self) -> SchemaRef {
75        RecordBatchReader::schema(self)
76    }
77
78    #[inline]
79    fn into_stream(self) -> SendableRecordBatchStream {
80        reader_to_stream(Box::new(self))
81    }
82}
83
84impl<I> StreamingWriteSource for RecordBatchIterator<I>
85where
86    Self: Send,
87    I: IntoIterator<Item = ::core::result::Result<RecordBatch, ArrowError>> + Send + 'static,
88{
89    #[inline]
90    fn arrow_schema(&self) -> SchemaRef {
91        RecordBatchReader::schema(self)
92    }
93
94    #[inline]
95    fn into_stream(self) -> SendableRecordBatchStream {
96        reader_to_stream(Box::new(self))
97    }
98}
99
100impl<T> StreamingWriteSource for Box<T>
101where
102    T: StreamingWriteSource,
103{
104    #[inline]
105    fn arrow_schema(&self) -> SchemaRef {
106        T::arrow_schema(&**self)
107    }
108
109    #[inline]
110    fn into_stream(self) -> SendableRecordBatchStream {
111        T::into_stream(*self)
112    }
113}
114
115impl StreamingWriteSource for Box<dyn RecordBatchReader + Send> {
116    #[inline]
117    fn arrow_schema(&self) -> SchemaRef {
118        RecordBatchReader::schema(self)
119    }
120
121    #[inline]
122    fn into_stream(self) -> SendableRecordBatchStream {
123        reader_to_stream(self)
124    }
125}
126
127impl StreamingWriteSource for SendableRecordBatchStream {
128    #[inline]
129    fn arrow_schema(&self) -> SchemaRef {
130        RecordBatchStream::schema(&**self)
131    }
132
133    #[inline]
134    fn into_stream(self) -> SendableRecordBatchStream {
135        self
136    }
137}
138
139/// Convert reader to a stream.
140///
141/// The reader will be called in a background thread.
142pub fn reader_to_stream(batches: Box<dyn RecordBatchReader + Send>) -> SendableRecordBatchStream {
143    let arrow_schema = batches.arrow_schema();
144    let stream = RecordBatchStreamAdapter::new(
145        arrow_schema,
146        BackgroundIterator::new(batches)
147            .fuse()
148            .map_err(DataFusionError::from),
149    );
150    Box::pin(stream)
151}
152
153pub trait MetricsExt {
154    fn find_count(&self, name: &str) -> Option<Count>;
155}
156
157impl MetricsExt for MetricsSet {
158    fn find_count(&self, metric_name: &str) -> Option<Count> {
159        self.iter().find_map(|m| match m.value() {
160            MetricValue::Count { name, count } => {
161                if name == metric_name {
162                    Some(count.clone())
163                } else {
164                    None
165                }
166            }
167            _ => None,
168        })
169    }
170}
171
172pub trait ExecutionPlanMetricsSetExt {
173    fn new_count(&self, name: &'static str, partition: usize) -> Count;
174}
175
176impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet {
177    fn new_count(&self, name: &'static str, partition: usize) -> Count {
178        let count = Count::new();
179        MetricBuilder::new(self)
180            .with_partition(partition)
181            .build(MetricValue::Count {
182                name: Cow::Borrowed(name),
183                count: count.clone(),
184            });
185        count
186    }
187}
188
// Common metrics
//
// Shared metric names, so code that records a metric and code that looks it up
// refer to the same string.

/// Metric name `"iops"` (presumably the number of I/O operations — confirm at the recording site).
pub const IOPS_METRIC: &str = "iops";
/// Metric name `"requests"` (presumably the number of requests issued — confirm at the recording site).
pub const REQUESTS_METRIC: &str = "requests";
/// Metric name `"bytes_read"` (presumably the number of bytes read — confirm at the recording site).
pub const BYTES_READ_METRIC: &str = "bytes_read";
/// Metric name `"indices_loaded"` (presumably the number of indices loaded — confirm at the recording site).
pub const INDICES_LOADED_METRIC: &str = "indices_loaded";
/// Metric name `"parts_loaded"` (presumably the number of index parts loaded — confirm at the recording site).
pub const PARTS_LOADED_METRIC: &str = "parts_loaded";
/// Metric name `"index_comparisons"` (presumably the number of comparisons made in an index — confirm at the recording site).
pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons";