Skip to main content

polars_arrow/io/ipc/read/
stream.rs

1use std::io::{Read, Seek};
2
3use arrow_format::ipc::planus::ReadAsRoot;
4use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err};
5use polars_utils::bool::UnsafeBool;
6
7use super::super::CONTINUATION_MARKER;
8use super::common::*;
9use super::schema::deserialize_stream_metadata;
10use super::{Dictionaries, OutOfSpecKind};
11use crate::array::Array;
12use crate::datatypes::{ArrowSchema, Metadata};
13use crate::io::ipc::IpcSchema;
14use crate::record_batch::RecordBatchT;
15
16/// Metadata of an Arrow IPC stream, written at the start of the stream
17#[derive(Debug, Clone)]
18pub struct StreamMetadata {
19    /// The schema that is read from the stream's first message
20    pub schema: ArrowSchema,
21
22    /// The custom metadata that is read from the schema
23    pub custom_schema_metadata: Option<Metadata>,
24
25    /// The IPC version of the stream
26    pub version: arrow_format::ipc::MetadataVersion,
27
28    /// The IPC fields tracking dictionaries
29    pub ipc_schema: IpcSchema,
30}
31
32/// Reads the metadata of the stream
33pub fn read_stream_metadata(reader: &mut dyn std::io::Read) -> PolarsResult<StreamMetadata> {
34    // determine metadata length
35    let mut meta_size: [u8; 4] = [0; 4];
36    reader.read_exact(&mut meta_size)?;
37    let meta_length = {
38        // If a continuation marker is encountered, skip over it and read
39        // the size from the next four bytes.
40        if meta_size == CONTINUATION_MARKER {
41            reader.read_exact(&mut meta_size)?;
42        }
43        i32::from_le_bytes(meta_size)
44    };
45
46    let length: usize = meta_length
47        .try_into()
48        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
49
50    let mut buffer = vec![];
51    buffer.try_reserve(length)?;
52    reader.take(length as u64).read_to_end(&mut buffer)?;
53
54    deserialize_stream_metadata(&buffer)
55}
56
57/// Encodes the stream's status after each read.
58///
59/// A stream is an iterator, and an iterator returns `Option<Item>`. The `Item`
60/// type in the [`StreamReader`] case is `StreamState`, which means that an Arrow
61/// stream may yield one of three values: (1) `None`, which signals that the stream
62/// is done; (2) [`StreamState::Some`], which signals that there was
63/// data waiting in the stream and we read it; and finally (3)
64/// [`Some(StreamState::Waiting)`], which means that the stream is still "live", it
65/// just doesn't hold any data right now.
66pub enum StreamState {
67    /// A live stream without data
68    Waiting,
69    /// Next item in the stream
70    Some(RecordBatchT<Box<dyn Array>>),
71}
72
73impl StreamState {
74    /// Return the data inside this wrapper.
75    ///
76    /// # Panics
77    ///
78    /// If the `StreamState` was `Waiting`.
79    pub fn unwrap(self) -> RecordBatchT<Box<dyn Array>> {
80        if let StreamState::Some(batch) = self {
81            batch
82        } else {
83            panic!("The batch is not available")
84        }
85    }
86}
87
88/// Reads the next item, yielding `None` if the stream is done,
89/// and a [`StreamState`] otherwise.
90fn read_next<R: Read + Seek>(
91    reader: &mut R,
92    metadata: &StreamMetadata,
93    dictionaries: &mut Dictionaries,
94    message_buffer: &mut Vec<u8>,
95    projection: &Option<ProjectionInfo>,
96    scratch: &mut Vec<u8>,
97    checked: UnsafeBool,
98) -> PolarsResult<Option<StreamState>> {
99    // determine metadata length
100    let mut meta_length: [u8; 4] = [0; 4];
101
102    match reader.read_exact(&mut meta_length) {
103        Ok(()) => (),
104        Err(e) => {
105            return if e.kind() == std::io::ErrorKind::UnexpectedEof {
106                // Handle EOF without the "0xFFFFFFFF 0x00000000"
107                // valid according to:
108                // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
109                Ok(Some(StreamState::Waiting))
110            } else {
111                Err(PolarsError::from(e))
112            };
113        },
114    }
115
116    let meta_length = {
117        // If a continuation marker is encountered, skip over it and read
118        // the size from the next four bytes.
119        if meta_length == CONTINUATION_MARKER {
120            reader.read_exact(&mut meta_length)?;
121        }
122        i32::from_le_bytes(meta_length)
123    };
124
125    let meta_length: usize = meta_length
126        .try_into()
127        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
128
129    if meta_length == 0 {
130        // the stream has ended, mark the reader as finished
131        return Ok(None);
132    }
133
134    message_buffer.clear();
135    message_buffer.try_reserve(meta_length)?;
136    reader
137        .by_ref()
138        .take(meta_length as u64)
139        .read_to_end(message_buffer)?;
140
141    let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer.as_ref())
142        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;
143
144    let header = message
145        .header()
146        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
147        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;
148
149    let block_length: usize = message
150        .body_length()
151        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?
152        .try_into()
153        .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;
154
155    match header {
156        arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
157            let cur_pos = reader.stream_position()?;
158
159            let chunk = read_record_batch(
160                batch,
161                &metadata.schema,
162                &metadata.ipc_schema,
163                projection.as_ref().map(|x| x.columns.as_ref()),
164                None,
165                dictionaries,
166                metadata.version,
167                &mut (&mut *reader).take(block_length as u64),
168                0,
169                scratch,
170                checked,
171            );
172
173            let new_pos = reader.stream_position()?;
174            let read_size = new_pos - cur_pos;
175
176            reader.seek(std::io::SeekFrom::Current(
177                block_length as i64 - read_size as i64,
178            ))?;
179
180            if let Some(ProjectionInfo { map, .. }) = projection {
181                // re-order according to projection
182                chunk
183                    .map(|chunk| apply_projection(chunk, map))
184                    .map(|x| Some(StreamState::Some(x)))
185            } else {
186                chunk.map(|x| Some(StreamState::Some(x)))
187            }
188        },
189        arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
190            let cur_pos = reader.stream_position()?;
191
192            read_dictionary(
193                batch,
194                &metadata.schema,
195                &metadata.ipc_schema,
196                dictionaries,
197                &mut (&mut *reader).take(block_length as u64),
198                0,
199                scratch,
200                checked,
201            )?;
202
203            let new_pos = reader.stream_position()?;
204            let read_size = new_pos - cur_pos;
205
206            reader.seek(std::io::SeekFrom::Current(
207                block_length as i64 - read_size as i64,
208            ))?;
209
210            // read the next message until we encounter a RecordBatch message
211            read_next(
212                reader,
213                metadata,
214                dictionaries,
215                message_buffer,
216                projection,
217                scratch,
218                checked,
219            )
220        },
221        _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType),
222    }
223}
224
225/// Arrow Stream reader.
226///
227/// An [`Iterator`] over an Arrow stream that yields a result of [`StreamState`]s.
228/// This is the recommended way to read an arrow stream (by iterating over its data).
229///
230/// For a more thorough walkthrough consult [this example](https://github.com/jorgecarleitao/polars_arrow/tree/main/examples/ipc_pyarrow).
231pub struct StreamReader<R: Read> {
232    reader: R,
233    metadata: StreamMetadata,
234    dictionaries: Dictionaries,
235    finished: bool,
236    message_buffer: Vec<u8>,
237    projection: Option<ProjectionInfo>,
238    scratch: Vec<u8>,
239    checked: UnsafeBool,
240}
241
242impl<R: Read + Seek> StreamReader<R> {
243    /// Try to create a new stream reader
244    ///
245    /// The first message in the stream is the schema, the reader will fail if it does not
246    /// encounter a schema.
247    /// To check if the reader is done, use `is_finished(self)`
248    pub fn new(reader: R, metadata: StreamMetadata, projection: Option<Vec<usize>>) -> Self {
249        let projection =
250            projection.map(|projection| prepare_projection(&metadata.schema, projection));
251
252        Self {
253            reader,
254            metadata,
255            dictionaries: Default::default(),
256            finished: false,
257            message_buffer: Default::default(),
258            projection,
259            scratch: Default::default(),
260            checked: UnsafeBool::default(),
261        }
262    }
263
264    /// # Safety
265    /// Don't do expensive checks.
266    /// This means the data source has to be trusted to be correct.
267    pub unsafe fn unchecked(mut self) -> Self {
268        unsafe {
269            self.checked = UnsafeBool::new_false();
270        }
271        self
272    }
273
274    /// Return the schema of the stream
275    pub fn metadata(&self) -> &StreamMetadata {
276        &self.metadata
277    }
278
279    /// Return the schema of the file
280    pub fn schema(&self) -> &ArrowSchema {
281        self.projection
282            .as_ref()
283            .map(|x| &x.schema)
284            .unwrap_or(&self.metadata.schema)
285    }
286
287    /// Check if the stream is finished
288    pub fn is_finished(&self) -> bool {
289        self.finished
290    }
291
292    fn maybe_next(&mut self) -> PolarsResult<Option<StreamState>> {
293        if self.finished {
294            return Ok(None);
295        }
296        let batch = read_next(
297            &mut self.reader,
298            &self.metadata,
299            &mut self.dictionaries,
300            &mut self.message_buffer,
301            &self.projection,
302            &mut self.scratch,
303            self.checked,
304        )?;
305        if batch.is_none() {
306            self.finished = true;
307        }
308        Ok(batch)
309    }
310}
311
312impl<R: Read + Seek> Iterator for StreamReader<R> {
313    type Item = PolarsResult<StreamState>;
314
315    fn next(&mut self) -> Option<Self::Item> {
316        self.maybe_next().transpose()
317    }
318}