orc_rust/
async_arrow_reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::predicate::Predicate;
34use crate::reader::metadata::read_metadata_async;
35use crate::reader::AsyncChunkReader;
36use crate::row_group_filter::evaluate_predicate;
37use crate::row_selection::RowSelection;
38use crate::schema::RootDataType;
39use crate::stripe::{Stripe, StripeMetadata};
40use crate::ArrowReaderBuilder;
41
42type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
43
44enum StreamState<T> {
45    /// At the start of a new row group, or the end of the file stream
46    Init,
47    /// Decoding a batch
48    Decoding(BoxedDecoder),
49    /// Reading data from input
50    Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
51    /// Error
52    Error,
53}
54
55impl<T> std::fmt::Debug for StreamState<T> {
56    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57        match self {
58            StreamState::Init => write!(f, "StreamState::Init"),
59            StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
60            StreamState::Reading(_) => write!(f, "StreamState::Reading"),
61            StreamState::Error => write!(f, "StreamState::Error"),
62        }
63    }
64}
65
66impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
67    fn from(c: Cursor<R>) -> Self {
68        Self {
69            inner: c,
70            is_end: false,
71        }
72    }
73}
74
75pub struct StripeFactory<R> {
76    inner: Cursor<R>,
77    is_end: bool,
78}
79
80pub struct ArrowStreamReader<R: AsyncChunkReader> {
81    factory: Option<Box<StripeFactory<R>>>,
82    batch_size: usize,
83    schema_ref: SchemaRef,
84    row_selection: Option<RowSelection>,
85    predicate: Option<Predicate>,
86    projected_data_type: RootDataType,
87    file_metadata: Arc<crate::reader::metadata::FileMetadata>,
88    state: StreamState<R>,
89}
90
91impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
92    async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
93        let inner = &mut self.inner;
94
95        inner.stripe_index += 1;
96
97        Stripe::new_async(
98            &mut inner.reader,
99            &inner.file_metadata,
100            &inner.projected_data_type,
101            info,
102        )
103        .await
104    }
105
106    /// Read the next stripe from the file.
107    pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
108        let info = self
109            .inner
110            .file_metadata
111            .stripe_metadatas()
112            .get(self.inner.stripe_index)
113            .cloned();
114
115        if let Some(info) = info {
116            if let Some(range) = self.inner.file_byte_range.clone() {
117                let offset = info.offset() as usize;
118                if !range.contains(&offset) {
119                    self.inner.stripe_index += 1;
120                    return Ok((self, None));
121                }
122            }
123            match self.read_next_stripe_inner(&info).await {
124                Ok(stripe) => Ok((self, Some(stripe))),
125                Err(err) => Err(err),
126            }
127        } else {
128            self.is_end = true;
129            Ok((self, None))
130        }
131    }
132}
133
134impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
135    pub(crate) fn new(
136        cursor: Cursor<R>,
137        batch_size: usize,
138        schema_ref: SchemaRef,
139        row_selection: Option<RowSelection>,
140        predicate: Option<Predicate>,
141        projected_data_type: RootDataType,
142        file_metadata: Arc<crate::reader::metadata::FileMetadata>,
143    ) -> Self {
144        Self {
145            factory: Some(Box::new(cursor.into())),
146            batch_size,
147            schema_ref,
148            row_selection,
149            predicate,
150            projected_data_type,
151            file_metadata,
152            state: StreamState::Init,
153        }
154    }
155
156    /// Extracts the inner `StripeFactory` and `SchemaRef` from the `ArrowStreamReader`.
157    pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
158        (self.factory, self.schema_ref)
159    }
160
161    pub fn schema(&self) -> SchemaRef {
162        self.schema_ref.clone()
163    }
164
165    fn poll_next_inner(
166        mut self: Pin<&mut Self>,
167        cx: &mut Context<'_>,
168    ) -> Poll<Option<Result<RecordBatch>>> {
169        loop {
170            match &mut self.state {
171                StreamState::Decoding(decoder) => match decoder.next() {
172                    Some(Ok(batch)) => {
173                        return Poll::Ready(Some(Ok(batch)));
174                    }
175                    Some(Err(e)) => {
176                        self.state = StreamState::Error;
177                        return Poll::Ready(Some(Err(e)));
178                    }
179                    None => self.state = StreamState::Init,
180                },
181                StreamState::Init => {
182                    let factory = self.factory.take().expect("lost factory");
183                    if factory.is_end {
184                        return Poll::Ready(None);
185                    }
186
187                    let fut = factory.read_next_stripe().boxed();
188
189                    self.state = StreamState::Reading(fut)
190                }
191                StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
192                    Ok((factory, Some(stripe))) => {
193                        self.factory = Some(Box::new(factory));
194
195                        let stripe_rows = stripe.number_of_rows();
196
197                        // Evaluate predicate if present
198                        let mut stripe_selection: Option<RowSelection> = None;
199                        if let Some(ref predicate) = self.predicate {
200                            // Try to read row indexes for this stripe
201                            match stripe.read_row_indexes(&self.file_metadata) {
202                                Ok(row_index) => {
203                                    // Evaluate predicate against row group statistics
204                                    match evaluate_predicate(
205                                        predicate,
206                                        &row_index,
207                                        &self.projected_data_type,
208                                    ) {
209                                        Ok(row_group_filter) => {
210                                            // Generate RowSelection from filter results
211                                            let rows_per_group = self
212                                                .file_metadata
213                                                .row_index_stride()
214                                                .unwrap_or(10_000);
215                                            stripe_selection =
216                                                Some(RowSelection::from_row_group_filter(
217                                                    &row_group_filter,
218                                                    rows_per_group,
219                                                    stripe_rows,
220                                                ));
221                                        }
222                                        Err(_) => {
223                                            // Predicate evaluation failed (e.g., column not found)
224                                            // Keep all rows (maybe)
225                                            stripe_selection =
226                                                Some(RowSelection::select_all(stripe_rows));
227                                        }
228                                    }
229                                }
230                                Err(_) => {
231                                    // Row indexes not available, keep all rows (maybe)
232                                    stripe_selection = Some(RowSelection::select_all(stripe_rows));
233                                }
234                            }
235                        }
236
237                        // Combine with existing row_selection if present
238                        let mut final_selection = stripe_selection;
239                        if let Some(ref mut existing_selection) = self.row_selection {
240                            if existing_selection.row_count() > 0 {
241                                let existing_for_stripe = existing_selection.split_off(stripe_rows);
242                                final_selection = match final_selection {
243                                    Some(predicate_selection) => {
244                                        // Both predicate and manual selection: combine with AND
245                                        Some(existing_for_stripe.and_then(&predicate_selection))
246                                    }
247                                    None => Some(existing_for_stripe),
248                                };
249                            }
250                        }
251
252                        match NaiveStripeDecoder::new_with_selection(
253                            stripe,
254                            self.schema_ref.clone(),
255                            self.batch_size,
256                            final_selection,
257                        ) {
258                            Ok(decoder) => {
259                                self.state = StreamState::Decoding(Box::new(decoder));
260                            }
261                            Err(e) => {
262                                self.state = StreamState::Error;
263                                return Poll::Ready(Some(Err(e)));
264                            }
265                        }
266                    }
267                    Ok((factory, None)) => {
268                        self.factory = Some(Box::new(factory));
269                        // All rows skipped, read next row group
270                        self.state = StreamState::Init;
271                    }
272                    Err(e) => {
273                        self.state = StreamState::Error;
274                        return Poll::Ready(Some(Err(e)));
275                    }
276                },
277                StreamState::Error => return Poll::Ready(None), // Ends the stream as error happens.
278            }
279        }
280    }
281}
282
283impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
284    type Item = Result<RecordBatch, ArrowError>;
285
286    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287        self.poll_next_inner(cx)
288            .map_err(|e| ArrowError::ExternalError(Box::new(e)))
289    }
290}
291
292impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
293    pub async fn try_new_async(mut reader: R) -> Result<Self> {
294        let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
295        Ok(Self::new(reader, file_metadata))
296    }
297
298    pub fn build_async(self) -> ArrowStreamReader<R> {
299        let projected_data_type = self
300            .file_metadata()
301            .root_data_type()
302            .project(&self.projection);
303        let projected_data_type_clone = projected_data_type.clone();
304        let schema_ref = self.schema();
305        let cursor = Cursor {
306            reader: self.reader,
307            file_metadata: self.file_metadata.clone(),
308            projected_data_type,
309            stripe_index: 0,
310            file_byte_range: self.file_byte_range,
311        };
312        ArrowStreamReader::new(
313            cursor,
314            self.batch_size,
315            schema_ref,
316            self.row_selection,
317            self.predicate,
318            projected_data_type_clone,
319            self.file_metadata,
320        )
321    }
322}