1use std::fmt::Formatter;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use arrow::datatypes::SchemaRef;
24use arrow::error::ArrowError;
25use arrow::record_batch::RecordBatch;
26use futures::future::BoxFuture;
27use futures::{ready, Stream};
28use futures_util::FutureExt;
29
30use crate::array_decoder::NaiveStripeDecoder;
31use crate::arrow_reader::Cursor;
32use crate::error::Result;
33use crate::predicate::Predicate;
34use crate::reader::metadata::read_metadata_async;
35use crate::reader::AsyncChunkReader;
36use crate::row_group_filter::evaluate_predicate;
37use crate::row_selection::RowSelection;
38use crate::schema::RootDataType;
39use crate::stripe::{Stripe, StripeMetadata};
40use crate::ArrowReaderBuilder;
41
/// Type-erased, sendable iterator that yields the decoded record batches of a
/// single stripe (or the first decoding error).
type BoxedDecoder = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;
43
/// State machine driving [`ArrowStreamReader`].
enum StreamState<T> {
    /// Nothing is in flight; the next poll starts reading another stripe
    /// (or ends the stream if the factory reports it is exhausted).
    Init,
    /// A stripe has been loaded and its record batches are being yielded.
    Decoding(BoxedDecoder),
    /// A stripe read is in flight. The future owns the factory and hands it
    /// back together with the stripe it read, or `None` when no stripe was
    /// produced by this step.
    Reading(BoxFuture<'static, Result<(StripeFactory<T>, Option<Stripe>)>>),
    /// An error has been returned to the caller; further polls yield `None`.
    Error,
}
54
55impl<T> std::fmt::Debug for StreamState<T> {
56 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57 match self {
58 StreamState::Init => write!(f, "StreamState::Init"),
59 StreamState::Decoding(_) => write!(f, "StreamState::Decoding"),
60 StreamState::Reading(_) => write!(f, "StreamState::Reading"),
61 StreamState::Error => write!(f, "StreamState::Error"),
62 }
63 }
64}
65
66impl<R: Send> From<Cursor<R>> for StripeFactory<R> {
67 fn from(c: Cursor<R>) -> Self {
68 Self {
69 inner: c,
70 is_end: false,
71 }
72 }
73}
74
/// Produces stripes one at a time from an underlying [`Cursor`].
pub struct StripeFactory<R> {
    /// Cursor holding the reader, file metadata, projection and the index of
    /// the next stripe to consider.
    inner: Cursor<R>,
    /// Set once the stripe index has run past the last stripe metadata entry.
    is_end: bool,
}
79
/// Asynchronous reader that yields Arrow record batches stripe by stripe.
pub struct ArrowStreamReader<R: AsyncChunkReader> {
    /// Stripe source. Taken out while a stripe read is in flight (the read
    /// future owns it) and put back when that read completes successfully.
    factory: Option<Box<StripeFactory<R>>>,
    /// Batch size handed to the stripe decoder.
    batch_size: usize,
    /// Arrow schema handed to the stripe decoder and returned by `schema()`.
    schema_ref: SchemaRef,
    /// Optional caller-supplied row selection; a per-stripe portion is split
    /// off the front each time a stripe is decoded.
    row_selection: Option<RowSelection>,
    /// Optional predicate evaluated against each stripe's row indexes.
    predicate: Option<Predicate>,
    /// Projected root type passed to predicate evaluation.
    projected_data_type: RootDataType,
    /// File-level metadata (used for row indexes and the row index stride).
    file_metadata: Arc<crate::reader::metadata::FileMetadata>,
    /// Current position in the read/decode state machine.
    state: StreamState<R>,
}
90
91impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
92 async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result<Stripe> {
93 let inner = &mut self.inner;
94
95 inner.stripe_index += 1;
96
97 Stripe::new_async(
98 &mut inner.reader,
99 &inner.file_metadata,
100 &inner.projected_data_type,
101 info,
102 )
103 .await
104 }
105
106 pub async fn read_next_stripe(mut self) -> Result<(Self, Option<Stripe>)> {
108 let info = self
109 .inner
110 .file_metadata
111 .stripe_metadatas()
112 .get(self.inner.stripe_index)
113 .cloned();
114
115 if let Some(info) = info {
116 if let Some(range) = self.inner.file_byte_range.clone() {
117 let offset = info.offset() as usize;
118 if !range.contains(&offset) {
119 self.inner.stripe_index += 1;
120 return Ok((self, None));
121 }
122 }
123 match self.read_next_stripe_inner(&info).await {
124 Ok(stripe) => Ok((self, Some(stripe))),
125 Err(err) => Err(err),
126 }
127 } else {
128 self.is_end = true;
129 Ok((self, None))
130 }
131 }
132}
133
134impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
135 pub(crate) fn new(
136 cursor: Cursor<R>,
137 batch_size: usize,
138 schema_ref: SchemaRef,
139 row_selection: Option<RowSelection>,
140 predicate: Option<Predicate>,
141 projected_data_type: RootDataType,
142 file_metadata: Arc<crate::reader::metadata::FileMetadata>,
143 ) -> Self {
144 Self {
145 factory: Some(Box::new(cursor.into())),
146 batch_size,
147 schema_ref,
148 row_selection,
149 predicate,
150 projected_data_type,
151 file_metadata,
152 state: StreamState::Init,
153 }
154 }
155
156 pub fn into_parts(self) -> (Option<Box<StripeFactory<R>>>, SchemaRef) {
158 (self.factory, self.schema_ref)
159 }
160
161 pub fn schema(&self) -> SchemaRef {
162 self.schema_ref.clone()
163 }
164
165 fn poll_next_inner(
166 mut self: Pin<&mut Self>,
167 cx: &mut Context<'_>,
168 ) -> Poll<Option<Result<RecordBatch>>> {
169 loop {
170 match &mut self.state {
171 StreamState::Decoding(decoder) => match decoder.next() {
172 Some(Ok(batch)) => {
173 return Poll::Ready(Some(Ok(batch)));
174 }
175 Some(Err(e)) => {
176 self.state = StreamState::Error;
177 return Poll::Ready(Some(Err(e)));
178 }
179 None => self.state = StreamState::Init,
180 },
181 StreamState::Init => {
182 let factory = self.factory.take().expect("lost factory");
183 if factory.is_end {
184 return Poll::Ready(None);
185 }
186
187 let fut = factory.read_next_stripe().boxed();
188
189 self.state = StreamState::Reading(fut)
190 }
191 StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
192 Ok((factory, Some(stripe))) => {
193 self.factory = Some(Box::new(factory));
194
195 let stripe_rows = stripe.number_of_rows();
196
197 let mut stripe_selection: Option<RowSelection> = None;
199 if let Some(ref predicate) = self.predicate {
200 match stripe.read_row_indexes(&self.file_metadata) {
202 Ok(row_index) => {
203 match evaluate_predicate(
205 predicate,
206 &row_index,
207 &self.projected_data_type,
208 ) {
209 Ok(row_group_filter) => {
210 let rows_per_group = self
212 .file_metadata
213 .row_index_stride()
214 .unwrap_or(10_000);
215 stripe_selection =
216 Some(RowSelection::from_row_group_filter(
217 &row_group_filter,
218 rows_per_group,
219 stripe_rows,
220 ));
221 }
222 Err(_) => {
223 stripe_selection =
226 Some(RowSelection::select_all(stripe_rows));
227 }
228 }
229 }
230 Err(_) => {
231 stripe_selection = Some(RowSelection::select_all(stripe_rows));
233 }
234 }
235 }
236
237 let mut final_selection = stripe_selection;
239 if let Some(ref mut existing_selection) = self.row_selection {
240 if existing_selection.row_count() > 0 {
241 let existing_for_stripe = existing_selection.split_off(stripe_rows);
242 final_selection = match final_selection {
243 Some(predicate_selection) => {
244 Some(existing_for_stripe.and_then(&predicate_selection))
246 }
247 None => Some(existing_for_stripe),
248 };
249 }
250 }
251
252 match NaiveStripeDecoder::new_with_selection(
253 stripe,
254 self.schema_ref.clone(),
255 self.batch_size,
256 final_selection,
257 ) {
258 Ok(decoder) => {
259 self.state = StreamState::Decoding(Box::new(decoder));
260 }
261 Err(e) => {
262 self.state = StreamState::Error;
263 return Poll::Ready(Some(Err(e)));
264 }
265 }
266 }
267 Ok((factory, None)) => {
268 self.factory = Some(Box::new(factory));
269 self.state = StreamState::Init;
271 }
272 Err(e) => {
273 self.state = StreamState::Error;
274 return Poll::Ready(Some(Err(e)));
275 }
276 },
277 StreamState::Error => return Poll::Ready(None), }
279 }
280 }
281}
282
283impl<R: AsyncChunkReader + 'static> Stream for ArrowStreamReader<R> {
284 type Item = Result<RecordBatch, ArrowError>;
285
286 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287 self.poll_next_inner(cx)
288 .map_err(|e| ArrowError::ExternalError(Box::new(e)))
289 }
290}
291
292impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
293 pub async fn try_new_async(mut reader: R) -> Result<Self> {
294 let file_metadata = Arc::new(read_metadata_async(&mut reader).await?);
295 Ok(Self::new(reader, file_metadata))
296 }
297
298 pub fn build_async(self) -> ArrowStreamReader<R> {
299 let projected_data_type = self
300 .file_metadata()
301 .root_data_type()
302 .project(&self.projection);
303 let projected_data_type_clone = projected_data_type.clone();
304 let schema_ref = self.schema();
305 let cursor = Cursor {
306 reader: self.reader,
307 file_metadata: self.file_metadata.clone(),
308 projected_data_type,
309 stripe_index: 0,
310 file_byte_range: self.file_byte_range,
311 };
312 ArrowStreamReader::new(
313 cursor,
314 self.batch_size,
315 schema_ref,
316 self.row_selection,
317 self.predicate,
318 projected_data_type_clone,
319 self.file_metadata,
320 )
321 }
322}