// arrow_ipc/reader/stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{RecordBatchDecoder, read_dictionary_impl};
29use crate::{CONTINUATION_MARKER, MessageHeader};
30
/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
///
/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface
#[derive(Debug, Default)]
pub struct StreamDecoder {
    /// The schema of this decoder, if read
    ///
    /// Populated by the first `Schema` message; `None` until then
    schema: Option<SchemaRef>,
    /// Lookup table for dictionaries by ID
    ///
    /// Filled in as `DictionaryBatch` messages are decoded, and consulted
    /// when decoding subsequent `RecordBatch` messages
    dictionaries: HashMap<i64, ArrayRef>,
    /// The decoder state (header / message / body / finished)
    state: DecoderState,
    /// A scratch buffer when a read is split across multiple `Buffer`
    ///
    /// Empty whenever the current message or body was decoded zero-copy
    buf: MutableBuffer,
    /// Whether or not array data in input buffers are required to be aligned
    require_alignment: bool,
    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// See [`StreamDecoder::with_skip_validation`] for details.
    ///
    skip_validation: UnsafeFlag,
}
52
/// The state machine driven by [`StreamDecoder::decode`]
///
/// Transitions: `Header` -> `Message` -> `Body` -> (`Header` | `Finished`)
#[derive(Debug)]
enum DecoderState {
    /// Decoding the message header: either the 4-byte continuation marker
    /// or the 4-byte little-endian message length
    Header {
        /// Temporary buffer holding the (possibly partial) 4 header bytes
        buf: [u8; 4],
        /// Number of bytes read into buf (0..=4)
        read: u8,
        /// If we have read a continuation token
        continuation: bool,
    },
    /// Decoding the message flatbuffer
    Message {
        /// The size of the message flatbuffer, from the length prefix
        size: u32,
    },
    /// Decoding the message body (the buffers following the flatbuffer)
    Body {
        /// The message flatbuffer
        message: MessageBuffer,
    },
    /// Reached the end of the stream (zero-length message seen)
    Finished,
}
77
78impl Default for DecoderState {
79    fn default() -> Self {
80        Self::Header {
81            buf: [0; 4],
82            read: 0,
83            continuation: false,
84        }
85    }
86}
87
88impl StreamDecoder {
89    /// Create a new [`StreamDecoder`]
90    pub fn new() -> Self {
91        Self::default()
92    }
93
94    /// Specifies whether or not array data in input buffers is required to be properly aligned.
95    ///
96    /// If `require_alignment` is true, this decoder will return an error if any array data in the
97    /// input `buf` is not properly aligned.
98    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
99    /// [`arrow_data::ArrayData`].
100    ///
101    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
102    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
103    /// properly aligned. (Properly aligned array data will remain zero-copy.)
104    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
105    /// [`arrow_data::ArrayData`].
106    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
107        self.require_alignment = require_alignment;
108        self
109    }
110
111    /// Return the schema if decoded, else None.
112    pub fn schema(&self) -> Option<SchemaRef> {
113        self.schema.as_ref().map(|schema| schema.clone())
114    }
115
116    /// Specifies if validation should be skipped when reading data (defaults to `false`)
117    ///
118    /// # Safety
119    ///
120    /// This flag must only be set to `true` when you trust the input data and are
121    /// sure the data you are reading is valid Arrow IPC stream data, otherwise
122    /// undefined behavior may result.
123    ///
124    /// For example, DataFusion uses this when reading spill files it wrote itself.
125    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
126        unsafe { self.skip_validation.set(skip_validation) };
127        self
128    }
129
130    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
131    ///
132    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
133    ///
134    /// The push-based interface facilitates integration with sources that yield arbitrarily
135    /// delimited bytes ranges, such as a chunked byte stream received from object storage
136    ///
137    /// ```
138    /// # use arrow_array::RecordBatch;
139    /// # use arrow_buffer::Buffer;
140    /// # use arrow_ipc::reader::StreamDecoder;
141    /// # use arrow_schema::ArrowError;
142    /// #
143    /// fn print_stream<I>(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
144    ///     let mut decoder = StreamDecoder::new();
145    ///     for mut x in src {
146    ///         while !x.is_empty() {
147    ///             if let Some(x) = decoder.decode(&mut x)? {
148    ///                 println!("{x:?}");
149    ///             }
150    ///             if let Some(schema) = decoder.schema() {
151    ///                 println!("Schema: {schema:?}");
152    ///             }
153    ///         }
154    ///     }
155    ///     decoder.finish().unwrap();
156    ///     Ok(())
157    /// }
158    /// ```
159    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
160        while !buffer.is_empty() {
161            match &mut self.state {
162                DecoderState::Header {
163                    buf,
164                    read,
165                    continuation,
166                } => {
167                    let offset_buf = &mut buf[*read as usize..];
168                    let to_read = buffer.len().min(offset_buf.len());
169                    offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
170                    *read += to_read as u8;
171                    buffer.advance(to_read);
172                    if *read == 4 {
173                        if !*continuation && buf == &CONTINUATION_MARKER {
174                            *continuation = true;
175                            *read = 0;
176                            continue;
177                        }
178                        let size = u32::from_le_bytes(*buf);
179
180                        if size == 0 {
181                            self.state = DecoderState::Finished;
182                            continue;
183                        }
184                        self.state = DecoderState::Message { size };
185                    }
186                }
187                DecoderState::Message { size } => {
188                    let len = *size as usize;
189                    if self.buf.is_empty() && buffer.len() > len {
190                        let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
191                        self.state = DecoderState::Body { message };
192                        buffer.advance(len);
193                        continue;
194                    }
195
196                    let to_read = buffer.len().min(len - self.buf.len());
197                    self.buf.extend_from_slice(&buffer[..to_read]);
198                    buffer.advance(to_read);
199                    if self.buf.len() == len {
200                        let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
201                        self.state = DecoderState::Body { message };
202                    }
203                }
204                DecoderState::Body { message } => {
205                    let message = message.as_ref();
206                    let body_length = message.bodyLength() as usize;
207
208                    let body = if self.buf.is_empty() && buffer.len() >= body_length {
209                        let body = buffer.slice_with_length(0, body_length);
210                        buffer.advance(body_length);
211                        body
212                    } else {
213                        let to_read = buffer.len().min(body_length - self.buf.len());
214                        self.buf.extend_from_slice(&buffer[..to_read]);
215                        buffer.advance(to_read);
216
217                        if self.buf.len() != body_length {
218                            continue;
219                        }
220                        std::mem::take(&mut self.buf).into()
221                    };
222
223                    let version = message.version();
224                    match message.header_type() {
225                        MessageHeader::Schema => {
226                            if self.schema.is_some() {
227                                return Err(ArrowError::IpcError(
228                                    "Not expecting a schema when messages are read".to_string(),
229                                ));
230                            }
231
232                            let ipc_schema = message.header_as_schema().unwrap();
233                            let schema = crate::convert::fb_to_schema(ipc_schema);
234                            self.state = DecoderState::default();
235                            self.schema = Some(Arc::new(schema));
236                        }
237                        MessageHeader::RecordBatch => {
238                            let batch = message.header_as_record_batch().unwrap();
239                            let schema = self.schema.clone().ok_or_else(|| {
240                                ArrowError::IpcError("Missing schema".to_string())
241                            })?;
242                            let batch = RecordBatchDecoder::try_new(
243                                &body,
244                                batch,
245                                schema,
246                                &self.dictionaries,
247                                &version,
248                            )?
249                            .with_require_alignment(self.require_alignment)
250                            .read_record_batch()?;
251                            self.state = DecoderState::default();
252                            return Ok(Some(batch));
253                        }
254                        MessageHeader::DictionaryBatch => {
255                            let dictionary = message.header_as_dictionary_batch().unwrap();
256                            let schema = self.schema.as_deref().ok_or_else(|| {
257                                ArrowError::IpcError("Missing schema".to_string())
258                            })?;
259                            read_dictionary_impl(
260                                &body,
261                                dictionary,
262                                schema,
263                                &mut self.dictionaries,
264                                &version,
265                                self.require_alignment,
266                                self.skip_validation.clone(),
267                            )?;
268                            self.state = DecoderState::default();
269                        }
270                        MessageHeader::NONE => {
271                            self.state = DecoderState::default();
272                        }
273                        t => {
274                            return Err(ArrowError::IpcError(format!(
275                                "Message type unsupported by StreamDecoder: {t:?}"
276                            )));
277                        }
278                    }
279                }
280                DecoderState::Finished => {
281                    return Err(ArrowError::IpcError("Unexpected EOS".to_string()));
282                }
283            }
284        }
285        Ok(None)
286    }
287
288    /// Signal the end of stream
289    ///
290    /// Returns an error if any partial data remains in the stream
291    pub fn finish(&mut self) -> Result<(), ArrowError> {
292        match self.state {
293            DecoderState::Finished
294            | DecoderState::Header {
295                read: 0,
296                continuation: false,
297                ..
298            } => Ok(()),
299            _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray, types::Int32Type,
    };
    use arrow_schema::{DataType, Field, Schema};

    // Further tests in arrow-integration-testing/tests/ipc_reader.rs

    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Truncate the final byte so the EOS marker is incomplete
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    #[test]
    fn test_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Truncated stream: schema decodes but EOS is incomplete
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap();
        assert!(output.is_none());
        let decoded_schema = decoder.schema().unwrap();
        assert_eq!(schema, decoded_schema);

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                IpcWriteOptions::default(),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        // BUG FIX: the loop binding previously shadowed `batch`, so the
        // assertion compared the decoded batch with itself. Bind the decoded
        // batch to a distinct name and compare it against the input batch.
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {e}").into())
            })
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
        }

        decoder.finish().expect("Failed to finish decoder");
    }
}
432}