lance_datafusion/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::borrow::Cow;
5
6use arrow::ffi_stream::ArrowArrayStreamReader;
7use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
8use arrow_schema::{ArrowError, SchemaRef};
9use async_trait::async_trait;
10use datafusion::{
11    execution::RecordBatchStream,
12    physical_plan::{
13        metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet},
14        stream::RecordBatchStreamAdapter,
15        SendableRecordBatchStream,
16    },
17};
18use datafusion_common::DataFusionError;
19use futures::{stream, Stream, StreamExt, TryFutureExt, TryStreamExt};
20use lance_core::datatypes::Schema;
21use lance_core::Result;
22use tokio::task::{spawn, spawn_blocking};
23
24fn background_iterator<I: Iterator + Send + 'static>(iter: I) -> impl Stream<Item = I::Item>
25where
26    I::Item: Send,
27{
28    stream::unfold(iter, |mut iter| {
29        spawn_blocking(|| iter.next().map(|val| (val, iter)))
30            .unwrap_or_else(|err| panic!("{}", err))
31    })
32    .fuse()
33}
34
35/// A trait for [BatchRecord] iterators, readers and streams
36/// that can be converted to a concrete stream type [SendableRecordBatchStream].
37///
38/// This also cam read the schema from the first batch
39/// and then update the schema to reflect the dictionary columns.
40#[async_trait]
41pub trait StreamingWriteSource: Send {
42    /// Infer the Lance schema from the first batch stream.
43    ///
44    /// This will peek the first batch to get the dictionaries for dictionary columns.
45    ///
46    /// NOTE: this does not validate the schema. For example, for appends the schema
47    /// should be checked to make sure it matches the existing dataset schema before
48    /// writing.
49    async fn into_stream_and_schema(self) -> Result<(SendableRecordBatchStream, Schema)>
50    where
51        Self: Sized,
52    {
53        let mut stream = self.into_stream();
54        let (stream, arrow_schema, schema) = spawn(async move {
55            let arrow_schema = stream.schema();
56            let mut schema: Schema = Schema::try_from(arrow_schema.as_ref())?;
57            let first_batch = stream.try_next().await?;
58            if let Some(batch) = &first_batch {
59                schema.set_dictionary(batch)?;
60            }
61            let stream = stream::iter(first_batch.map(Ok)).chain(stream);
62            Result::Ok((stream, arrow_schema, schema))
63        })
64        .await
65        .unwrap()?;
66        schema.validate()?;
67        let adapter = RecordBatchStreamAdapter::new(arrow_schema, stream);
68        Ok((Box::pin(adapter), schema))
69    }
70
71    /// Returns the arrow schema.
72    fn arrow_schema(&self) -> SchemaRef;
73
74    /// Convert to a stream.
75    ///
76    /// The conversion will be conducted in a background thread.
77    fn into_stream(self) -> SendableRecordBatchStream;
78}
79
80impl StreamingWriteSource for ArrowArrayStreamReader {
81    #[inline]
82    fn arrow_schema(&self) -> SchemaRef {
83        RecordBatchReader::schema(self)
84    }
85
86    #[inline]
87    fn into_stream(self) -> SendableRecordBatchStream {
88        reader_to_stream(Box::new(self))
89    }
90}
91
92impl<I> StreamingWriteSource for RecordBatchIterator<I>
93where
94    Self: Send,
95    I: IntoIterator<Item = ::core::result::Result<RecordBatch, ArrowError>> + Send + 'static,
96{
97    #[inline]
98    fn arrow_schema(&self) -> SchemaRef {
99        RecordBatchReader::schema(self)
100    }
101
102    #[inline]
103    fn into_stream(self) -> SendableRecordBatchStream {
104        reader_to_stream(Box::new(self))
105    }
106}
107
108impl<T> StreamingWriteSource for Box<T>
109where
110    T: StreamingWriteSource,
111{
112    #[inline]
113    fn arrow_schema(&self) -> SchemaRef {
114        T::arrow_schema(&**self)
115    }
116
117    #[inline]
118    fn into_stream(self) -> SendableRecordBatchStream {
119        T::into_stream(*self)
120    }
121}
122
123impl StreamingWriteSource for Box<dyn RecordBatchReader + Send> {
124    #[inline]
125    fn arrow_schema(&self) -> SchemaRef {
126        RecordBatchReader::schema(self)
127    }
128
129    #[inline]
130    fn into_stream(self) -> SendableRecordBatchStream {
131        reader_to_stream(self)
132    }
133}
134
135impl StreamingWriteSource for SendableRecordBatchStream {
136    #[inline]
137    fn arrow_schema(&self) -> SchemaRef {
138        RecordBatchStream::schema(&**self)
139    }
140
141    #[inline]
142    fn into_stream(self) -> SendableRecordBatchStream {
143        self
144    }
145}
146
147/// Convert reader to a stream.
148///
149/// The reader will be called in a background thread.
150pub fn reader_to_stream(batches: Box<dyn RecordBatchReader + Send>) -> SendableRecordBatchStream {
151    let arrow_schema = batches.arrow_schema();
152    let stream = RecordBatchStreamAdapter::new(
153        arrow_schema,
154        background_iterator(batches).map_err(DataFusionError::from),
155    );
156    Box::pin(stream)
157}
158
159pub trait MetricsExt {
160    fn find_count(&self, name: &str) -> Option<Count>;
161}
162
163impl MetricsExt for MetricsSet {
164    fn find_count(&self, metric_name: &str) -> Option<Count> {
165        self.iter().find_map(|m| match m.value() {
166            MetricValue::Count { name, count } => {
167                if name == metric_name {
168                    Some(count.clone())
169                } else {
170                    None
171                }
172            }
173            _ => None,
174        })
175    }
176}
177
178pub trait ExecutionPlanMetricsSetExt {
179    fn new_count(&self, name: &'static str, partition: usize) -> Count;
180}
181
182impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet {
183    fn new_count(&self, name: &'static str, partition: usize) -> Count {
184        let count = Count::new();
185        MetricBuilder::new(self)
186            .with_partition(partition)
187            .build(MetricValue::Count {
188                name: Cow::Borrowed(name),
189                count: count.clone(),
190            });
191        count
192    }
193}
194
// Names of metrics shared across execution nodes. Only the string keys are fixed
// here; what each one measures is decided by the code that records it.

/// Metric name `"iops"` (presumably I/O operation count — confirm at call sites).
pub const IOPS_METRIC: &str = "iops";
/// Metric name `"requests"`.
pub const REQUESTS_METRIC: &str = "requests";
/// Metric name `"bytes_read"`.
pub const BYTES_READ_METRIC: &str = "bytes_read";
/// Metric name `"indices_loaded"`.
pub const INDICES_LOADED_METRIC: &str = "indices_loaded";
/// Metric name `"parts_loaded"`.
pub const PARTS_LOADED_METRIC: &str = "parts_loaded";
/// Metric name `"index_comparisons"`.
pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons";