orc_rust/
async_arrow_reader.rs1use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::reader::metadata::read_metadata_async;
34use crate::reader::AsyncChunkReader;
35use crate::row_selection::RowSelection;
36use crate::stripe::{Stripe, StripeMetadata};
37use crate::ArrowReaderBuilder;
38
39type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
40
41enum StreamState<T> {
42 Init,
44 Decoding(BoxedDecoder),
46 Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
48 Error,
50}
51
52impl<T> std::fmt::Debug for StreamState<T> {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 match self {
55 StreamState::Init => write!(f, "StreamState::Init"),
56 StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
57 StreamState::Reading(_) => write!(f, "StreamState::Reading"),
58 StreamState::Error => write!(f, "StreamState::Error"),
59 }
60 }
61}
62
63impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
64 fn from(c: Cursor<R>) -> Self {
65 Self {
66 inner: c,
67 is_end: false,
68 }
69 }
70}
71
72pub struct StripeFactory<R> {
73 inner: Cursor<R>,
74 is_end: bool,
75}
76
77pub struct ArrowStreamReader<R: AsyncChunkReader> {
78 factory: Option<Box<StripeFactory<R>>>,
79 batch_size: usize,
80 schema_ref: SchemaRef,
81 row_selection: Option<RowSelection>,
82 state: StreamState<R>,
83}
84
85impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
86 async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
87 let inner = &mut self.inner;
88
89 inner.stripe_index += 1;
90
91 Stripe::new_async(
92 &mut inner.reader,
93 &inner.file_metadata,
94 &inner.projected_data_type,
95 info,
96 )
97 .await
98 }
99
100 pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
102 let info = self
103 .inner
104 .file_metadata
105 .stripe_metadatas()
106 .get(self.inner.stripe_index)
107 .cloned();
108
109 if let Some(info) = info {
110 if let Some(range) = self.inner.file_byte_range.clone() {
111 let offset = info.offset() as usize;
112 if !range.contains(&offset) {
113 self.inner.stripe_index += 1;
114 return Ok((self, None));
115 }
116 }
117 match self.read_next_stripe_inner(&info).await {
118 Ok(stripe) => Ok((self, Some(stripe))),
119 Err(err) => Err(err),
120 }
121 } else {
122 self.is_end = true;
123 Ok((self, None))
124 }
125 }
126}
127
128impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
129 pub(crate) fn new(
130 cursor: Cursor<R>,
131 batch_size: usize,
132 schema_ref: SchemaRef,
133 row_selection: Option<RowSelection>,
134 ) -> Self {
135 Self {
136 factory: Some(Box::new(cursor.into())),
137 batch_size,
138 schema_ref,
139 row_selection,
140 state: StreamState::Init,
141 }
142 }
143
144 pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
146 (self.factory, self.schema_ref)
147 }
148
149 pub fn schema(&self) -> SchemaRef {
150 self.schema_ref.clone()
151 }
152
153 fn poll_next_inner(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 ) -> Poll<Option<Result<RecordBatch>>> {
157 loop {
158 match &mut self.state {
159 StreamState::Decoding(decoder) => match decoder.next() {
160 Some(Ok(batch)) => {
161 return Poll::Ready(Some(Ok(batch)));
162 }
163 Some(Err(e)) => {
164 self.state = StreamState::Error;
165 return Poll::Ready(Some(Err(e)));
166 }
167 None => self.state = StreamState::Init,
168 },
169 StreamState::Init => {
170 let factory = self.factory.take().expect("lost factory");
171 if factory.is_end {
172 return Poll::Ready(None);
173 }
174
175 let fut = factory.read_next_stripe().boxed();
176
177 self.state = StreamState::Reading(fut)
178 }
179 StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
180 Ok((factory, Some(stripe))) => {
181 self.factory = Some(Box::new(factory));
182
183 let stripe_rows = stripe.number_of_rows();
185 let selection = self.row_selection.as_mut().and_then(|s| {
186 if s.row_count() > 0 {
187 Some(s.split_off(stripe_rows))
188 } else {
189 None
190 }
191 });
192
193 match NaiveStripeDecoder::new_with_selection(
194 stripe,
195 self.schema_ref.clone(),
196 self.batch_size,
197 selection,
198 ) {
199 Ok(decoder) => {
200 self.state = StreamState::Decoding(Box::new(decoder));
201 }
202 Err(e) => {
203 self.state = StreamState::Error;
204 return Poll::Ready(Some(Err(e)));
205 }
206 }
207 }
208 Ok((factory, None)) => {
209 self.factory = Some(Box::new(factory));
210 self.state = StreamState::Init;
212 }
213 Err(e) => {
214 self.state = StreamState::Error;
215 return Poll::Ready(Some(Err(e)));
216 }
217 },
218 StreamState::Error => return Poll::Ready(None), }
220 }
221 }
222}
223
224impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
225 type Item = Result<RecordBatch, ArrowError>;
226
227 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
228 self.poll_next_inner(cx)
229 .map_err(|e| ArrowError::ExternalError(Box::new(e)))
230 }
231}
232
233impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
234 pub async fn try_new_async(mut reader: R) -> Result<Self> {
235 let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
236 Ok(Self::new(reader, file_metadata))
237 }
238
239 pub fn build_async(self) -> ArrowStreamReader<R> {
240 let projected_data_type = self
241 .file_metadata()
242 .root_data_type()
243 .project(&self.projection);
244 let schema_ref = self.schema();
245 let cursor = Cursor {
246 reader: self.reader,
247 file_metadata: self.file_metadata,
248 projected_data_type,
249 stripe_index: 0,
250 file_byte_range: self.file_byte_range,
251 };
252 ArrowStreamReader::new(cursor, self.batch_size, schema_ref, self.row_selection)
253 }
254}