orc_rust/
async_arrow_reader.rs1use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::reader::metadata::read_metadata_async;
34use crate::reader::AsyncChunkReader;
35use crate::stripe::{Stripe, StripeMetadata};
36use crate::ArrowReaderBuilder;
37
38type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
39
40enum StreamState<T> {
41 Init,
43 Decoding(BoxedDecoder),
45 Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
47 Error,
49}
50
51impl<T> std::fmt::Debug for StreamState<T> {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 match self {
54 StreamState::Init => write!(f, "StreamState::Init"),
55 StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
56 StreamState::Reading(_) => write!(f, "StreamState::Reading"),
57 StreamState::Error => write!(f, "StreamState::Error"),
58 }
59 }
60}
61
62impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
63 fn from(c: Cursor<R>) -> Self {
64 Self {
65 inner: c,
66 is_end: false,
67 }
68 }
69}
70
71pub struct StripeFactory<R> {
72 inner: Cursor<R>,
73 is_end: bool,
74}
75
76pub struct ArrowStreamReader<R: AsyncChunkReader> {
77 factory: Option<Box<StripeFactory<R>>>,
78 batch_size: usize,
79 schema_ref: SchemaRef,
80 state: StreamState<R>,
81}
82
83impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
84 async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
85 let inner = &mut self.inner;
86
87 inner.stripe_index += 1;
88
89 Stripe::new_async(
90 &mut inner.reader,
91 &inner.file_metadata,
92 &inner.projected_data_type,
93 info,
94 )
95 .await
96 }
97
98 pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
100 let info = self
101 .inner
102 .file_metadata
103 .stripe_metadatas()
104 .get(self.inner.stripe_index)
105 .cloned();
106
107 if let Some(info) = info {
108 if let Some(range) = self.inner.file_byte_range.clone() {
109 let offset = info.offset() as usize;
110 if !range.contains(&offset) {
111 self.inner.stripe_index += 1;
112 return Ok((self, None));
113 }
114 }
115 match self.read_next_stripe_inner(&info).await {
116 Ok(stripe) => Ok((self, Some(stripe))),
117 Err(err) => Err(err),
118 }
119 } else {
120 self.is_end = true;
121 Ok((self, None))
122 }
123 }
124}
125
126impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
127 pub(crate) fn new(cursor: Cursor<R>, batch_size: usize, schema_ref: SchemaRef) -> Self {
128 Self {
129 factory: Some(Box::new(cursor.into())),
130 batch_size,
131 schema_ref,
132 state: StreamState::Init,
133 }
134 }
135
136 pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
138 (self.factory, self.schema_ref)
139 }
140
141 pub fn schema(&self) -> SchemaRef {
142 self.schema_ref.clone()
143 }
144
145 fn poll_next_inner(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 ) -> Poll<Option<Result<RecordBatch>>> {
149 loop {
150 match &mut self.state {
151 StreamState::Decoding(decoder) => match decoder.next() {
152 Some(Ok(batch)) => {
153 return Poll::Ready(Some(Ok(batch)));
154 }
155 Some(Err(e)) => {
156 self.state = StreamState::Error;
157 return Poll::Ready(Some(Err(e)));
158 }
159 None => self.state = StreamState::Init,
160 },
161 StreamState::Init => {
162 let factory = self.factory.take().expect("lost factory");
163 if factory.is_end {
164 return Poll::Ready(None);
165 }
166
167 let fut = factory.read_next_stripe().boxed();
168
169 self.state = StreamState::Reading(fut)
170 }
171 StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
172 Ok((factory, Some(stripe))) => {
173 self.factory = Some(Box::new(factory));
174 match NaiveStripeDecoder::new(
175 stripe,
176 self.schema_ref.clone(),
177 self.batch_size,
178 ) {
179 Ok(decoder) => {
180 self.state = StreamState::Decoding(Box::new(decoder));
181 }
182 Err(e) => {
183 self.state = StreamState::Error;
184 return Poll::Ready(Some(Err(e)));
185 }
186 }
187 }
188 Ok((factory, None)) => {
189 self.factory = Some(Box::new(factory));
190 self.state = StreamState::Init;
192 }
193 Err(e) => {
194 self.state = StreamState::Error;
195 return Poll::Ready(Some(Err(e)));
196 }
197 },
198 StreamState::Error => return Poll::Ready(None), }
200 }
201 }
202}
203
204impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
205 type Item = Result<RecordBatch, ArrowError>;
206
207 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
208 self.poll_next_inner(cx)
209 .map_err(|e| ArrowError::ExternalError(Box::new(e)))
210 }
211}
212
213impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
214 pub async fn try_new_async(mut reader: R) -> Result<Self> {
215 let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
216 Ok(Self::new(reader, file_metadata))
217 }
218
219 pub fn build_async(self) -> ArrowStreamReader<R> {
220 let projected_data_type = self
221 .file_metadata()
222 .root_data_type()
223 .project(&self.projection);
224 let schema_ref = self.schema();
225 let cursor = Cursor {
226 reader: self.reader,
227 file_metadata: self.file_metadata,
228 projected_data_type,
229 stripe_index: 0,
230 file_byte_range: self.file_byte_range,
231 };
232 ArrowStreamReader::new(cursor, self.batch_size, schema_ref)
233 }
234}