orc_rust/
async_arrow_reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::reader::metadata::read_metadata_async;
34use crate::reader::AsyncChunkReader;
35use crate::stripe::{Stripe, StripeMetadata};
36use crate::ArrowReaderBuilder;
37
38type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
39
40enum StreamState<T> {
41    /// At the start of a new row group, or the end of the file stream
42    Init,
43    /// Decoding a batch
44    Decoding(BoxedDecoder),
45    /// Reading data from input
46    Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
47    /// Error
48    Error,
49}
50
51impl<T> std::fmt::Debug for StreamState<T> {
52    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53        match self {
54            StreamState::Init => write!(f, "StreamState::Init"),
55            StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
56            StreamState::Reading(_) => write!(f, "StreamState::Reading"),
57            StreamState::Error => write!(f, "StreamState::Error"),
58        }
59    }
60}
61
62impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
63    fn from(c: Cursor<R>) -> Self {
64        Self {
65            inner: c,
66            is_end: false,
67        }
68    }
69}
70
71pub struct StripeFactory<R> {
72    inner: Cursor<R>,
73    is_end: bool,
74}
75
76pub struct ArrowStreamReader<R: AsyncChunkReader> {
77    factory: Option<Box<StripeFactory<R>>>,
78    batch_size: usize,
79    schema_ref: SchemaRef,
80    state: StreamState<R>,
81}
82
83impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
84    async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
85        let inner = &mut self.inner;
86
87        inner.stripe_index += 1;
88
89        Stripe::new_async(
90            &mut inner.reader,
91            &inner.file_metadata,
92            &inner.projected_data_type,
93            info,
94        )
95        .await
96    }
97
98    /// Read the next stripe from the file.
99    pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
100        let info = self
101            .inner
102            .file_metadata
103            .stripe_metadatas()
104            .get(self.inner.stripe_index)
105            .cloned();
106
107        if let Some(info) = info {
108            if let Some(range) = self.inner.file_byte_range.clone() {
109                let offset = info.offset() as usize;
110                if !range.contains(&offset) {
111                    self.inner.stripe_index += 1;
112                    return Ok((self, None));
113                }
114            }
115            match self.read_next_stripe_inner(&info).await {
116                Ok(stripe) => Ok((self, Some(stripe))),
117                Err(err) => Err(err),
118            }
119        } else {
120            self.is_end = true;
121            Ok((self, None))
122        }
123    }
124}
125
126impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
127    pub(crate) fn new(cursor: Cursor<R>, batch_size: usize, schema_ref: SchemaRef) -> Self {
128        Self {
129            factory: Some(Box::new(cursor.into())),
130            batch_size,
131            schema_ref,
132            state: StreamState::Init,
133        }
134    }
135
136    /// Extracts the inner `StripeFactory` and `SchemaRef` from the `ArrowStreamReader`.
137    pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
138        (self.factory, self.schema_ref)
139    }
140
141    pub fn schema(&self) -> SchemaRef {
142        self.schema_ref.clone()
143    }
144
145    fn poll_next_inner(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148    ) -> Poll<Option<Result<RecordBatch>>> {
149        loop {
150            match &mut self.state {
151                StreamState::Decoding(decoder) => match decoder.next() {
152                    Some(Ok(batch)) => {
153                        return Poll::Ready(Some(Ok(batch)));
154                    }
155                    Some(Err(e)) => {
156                        self.state = StreamState::Error;
157                        return Poll::Ready(Some(Err(e)));
158                    }
159                    None => self.state = StreamState::Init,
160                },
161                StreamState::Init => {
162                    let factory = self.factory.take().expect("lost factory");
163                    if factory.is_end {
164                        return Poll::Ready(None);
165                    }
166
167                    let fut = factory.read_next_stripe().boxed();
168
169                    self.state = StreamState::Reading(fut)
170                }
171                StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
172                    Ok((factory, Some(stripe))) => {
173                        self.factory = Some(Box::new(factory));
174                        match NaiveStripeDecoder::new(
175                            stripe,
176                            self.schema_ref.clone(),
177                            self.batch_size,
178                        ) {
179                            Ok(decoder) => {
180                                self.state = StreamState::Decoding(Box::new(decoder));
181                            }
182                            Err(e) => {
183                                self.state = StreamState::Error;
184                                return Poll::Ready(Some(Err(e)));
185                            }
186                        }
187                    }
188                    Ok((factory, None)) => {
189                        self.factory = Some(Box::new(factory));
190                        // All rows skipped, read next row group
191                        self.state = StreamState::Init;
192                    }
193                    Err(e) => {
194                        self.state = StreamState::Error;
195                        return Poll::Ready(Some(Err(e)));
196                    }
197                },
198                StreamState::Error => return Poll::Ready(None), // Ends the stream as error happens.
199            }
200        }
201    }
202}
203
204impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
205    type Item = Result<RecordBatch, ArrowError>;
206
207    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
208        self.poll_next_inner(cx)
209            .map_err(|e| ArrowError::ExternalError(Box::new(e)))
210    }
211}
212
213impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
214    pub async fn try_new_async(mut reader: R) -> Result<Self> {
215        let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
216        Ok(Self::new(reader, file_metadata))
217    }
218
219    pub fn build_async(self) -> ArrowStreamReader<R> {
220        let projected_data_type = self
221            .file_metadata()
222            .root_data_type()
223            .project(&self.projection);
224        let schema_ref = self.schema();
225        let cursor = Cursor {
226            reader: self.reader,
227            file_metadata: self.file_metadata,
228            projected_data_type,
229            stripe_index: 0,
230            file_byte_range: self.file_byte_range,
231        };
232        ArrowStreamReader::new(cursor, self.batch_size, schema_ref)
233    }
234}