orc_rust/
async_arrow_reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::reader::metadata::read_metadata_async;
34use crate::reader::AsyncChunkReader;
35use crate::row_selection::RowSelection;
36use crate::stripe::{Stripe, StripeMetadata};
37use crate::ArrowReaderBuilder;
38
39type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
40
41enum StreamState<T> {
42    /// At the start of a new row group, or the end of the file stream
43    Init,
44    /// Decoding a batch
45    Decoding(BoxedDecoder),
46    /// Reading data from input
47    Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
48    /// Error
49    Error,
50}
51
52impl<T> std::fmt::Debug for StreamState<T> {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        match self {
55            StreamState::Init => write!(f, "StreamState::Init"),
56            StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
57            StreamState::Reading(_) => write!(f, "StreamState::Reading"),
58            StreamState::Error => write!(f, "StreamState::Error"),
59        }
60    }
61}
62
63impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
64    fn from(c: Cursor<R>) -> Self {
65        Self {
66            inner: c,
67            is_end: false,
68        }
69    }
70}
71
72pub struct StripeFactory<R> {
73    inner: Cursor<R>,
74    is_end: bool,
75}
76
77pub struct ArrowStreamReader<R: AsyncChunkReader> {
78    factory: Option<Box<StripeFactory<R>>>,
79    batch_size: usize,
80    schema_ref: SchemaRef,
81    row_selection: Option<RowSelection>,
82    state: StreamState<R>,
83}
84
85impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
86    async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
87        let inner = &mut self.inner;
88
89        inner.stripe_index += 1;
90
91        Stripe::new_async(
92            &mut inner.reader,
93            &inner.file_metadata,
94            &inner.projected_data_type,
95            info,
96        )
97        .await
98    }
99
100    /// Read the next stripe from the file.
101    pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
102        let info = self
103            .inner
104            .file_metadata
105            .stripe_metadatas()
106            .get(self.inner.stripe_index)
107            .cloned();
108
109        if let Some(info) = info {
110            if let Some(range) = self.inner.file_byte_range.clone() {
111                let offset = info.offset() as usize;
112                if !range.contains(&offset) {
113                    self.inner.stripe_index += 1;
114                    return Ok((self, None));
115                }
116            }
117            match self.read_next_stripe_inner(&info).await {
118                Ok(stripe) => Ok((self, Some(stripe))),
119                Err(err) => Err(err),
120            }
121        } else {
122            self.is_end = true;
123            Ok((self, None))
124        }
125    }
126}
127
128impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
129    pub(crate) fn new(
130        cursor: Cursor<R>,
131        batch_size: usize,
132        schema_ref: SchemaRef,
133        row_selection: Option<RowSelection>,
134    ) -> Self {
135        Self {
136            factory: Some(Box::new(cursor.into())),
137            batch_size,
138            schema_ref,
139            row_selection,
140            state: StreamState::Init,
141        }
142    }
143
144    /// Extracts the inner `StripeFactory` and `SchemaRef` from the `ArrowStreamReader`.
145    pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
146        (self.factory, self.schema_ref)
147    }
148
149    pub fn schema(&self) -> SchemaRef {
150        self.schema_ref.clone()
151    }
152
153    fn poll_next_inner(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156    ) -> Poll<Option<Result<RecordBatch>>> {
157        loop {
158            match &mut self.state {
159                StreamState::Decoding(decoder) => match decoder.next() {
160                    Some(Ok(batch)) => {
161                        return Poll::Ready(Some(Ok(batch)));
162                    }
163                    Some(Err(e)) => {
164                        self.state = StreamState::Error;
165                        return Poll::Ready(Some(Err(e)));
166                    }
167                    None => self.state = StreamState::Init,
168                },
169                StreamState::Init => {
170                    let factory = self.factory.take().expect("lost factory");
171                    if factory.is_end {
172                        return Poll::Ready(None);
173                    }
174
175                    let fut = factory.read_next_stripe().boxed();
176
177                    self.state = StreamState::Reading(fut)
178                }
179                StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
180                    Ok((factory, Some(stripe))) => {
181                        self.factory = Some(Box::new(factory));
182
183                        // Split off the row selection for this stripe
184                        let stripe_rows = stripe.number_of_rows();
185                        let selection = self.row_selection.as_mut().and_then(|s| {
186                            if s.row_count() > 0 {
187                                Some(s.split_off(stripe_rows))
188                            } else {
189                                None
190                            }
191                        });
192
193                        match NaiveStripeDecoder::new_with_selection(
194                            stripe,
195                            self.schema_ref.clone(),
196                            self.batch_size,
197                            selection,
198                        ) {
199                            Ok(decoder) => {
200                                self.state = StreamState::Decoding(Box::new(decoder));
201                            }
202                            Err(e) => {
203                                self.state = StreamState::Error;
204                                return Poll::Ready(Some(Err(e)));
205                            }
206                        }
207                    }
208                    Ok((factory, None)) => {
209                        self.factory = Some(Box::new(factory));
210                        // All rows skipped, read next row group
211                        self.state = StreamState::Init;
212                    }
213                    Err(e) => {
214                        self.state = StreamState::Error;
215                        return Poll::Ready(Some(Err(e)));
216                    }
217                },
218                StreamState::Error => return Poll::Ready(None), // Ends the stream as error happens.
219            }
220        }
221    }
222}
223
224impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
225    type Item = Result<RecordBatch, ArrowError>;
226
227    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
228        self.poll_next_inner(cx)
229            .map_err(|e| ArrowError::ExternalError(Box::new(e)))
230    }
231}
232
233impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
234    pub async fn try_new_async(mut reader: R) -> Result<Self> {
235        let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
236        Ok(Self::new(reader, file_metadata))
237    }
238
239    pub fn build_async(self) -> ArrowStreamReader<R> {
240        let projected_data_type = self
241            .file_metadata()
242            .root_data_type()
243            .project(&self.projection);
244        let schema_ref = self.schema();
245        let cursor = Cursor {
246            reader: self.reader,
247            file_metadata: self.file_metadata,
248            projected_data_type,
249            stripe_index: 0,
250            file_byte_range: self.file_byte_range,
251        };
252        ArrowStreamReader::new(cursor, self.batch_size, schema_ref, self.row_selection)
253    }
254}