1use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{RecordBatchDecoder, read_dictionary_impl};
29use crate::{CONTINUATION_MARKER, MessageHeader};
30
/// A decoder that incrementally reconstructs [`RecordBatch`]es from Arrow IPC stream bytes,
/// which are supplied chunk by chunk via [`StreamDecoder::decode`].
#[derive(Debug, Default)]
pub struct StreamDecoder {
    /// The schema of the stream; `None` until a schema message has been decoded
    schema: Option<SchemaRef>,
    /// Dictionaries decoded so far, used when decoding record batches
    /// (presumably keyed by dictionary id — confirm against `read_dictionary_impl`)
    dictionaries: HashMap<i64, ArrayRef>,
    /// The current position within the message being decoded
    state: DecoderState,
    /// Scratch space accumulating metadata or body bytes of a message split across calls
    buf: MutableBuffer,
    /// Set by [`StreamDecoder::with_require_alignment`] and forwarded to message decoding
    require_alignment: bool,
    /// Set by [`StreamDecoder::with_skip_validation`]; controls whether decoded data is validated
    skip_validation: UnsafeFlag,
}
52
/// Position of a [`StreamDecoder`] within the message currently being decoded.
///
/// Each message is read as a 4-byte little-endian length prefix (optionally preceded by a
/// continuation marker), then `size` bytes of metadata, then a body whose length the metadata
/// declares.
#[derive(Debug)]
enum DecoderState {
    /// Reading the 4-byte prefix of the next message
    Header {
        /// Bytes of the prefix read so far
        buf: [u8; 4],
        /// Number of valid bytes in `buf`
        read: u8,
        /// Whether the continuation marker has already been consumed for this message
        continuation: bool,
    },
    /// Reading the message metadata
    Message {
        /// Length in bytes of the metadata, taken from the little-endian prefix
        size: u32,
    },
    /// Reading the message body
    Body {
        /// The decoded metadata, which declares the body length and message type
        message: MessageBuffer,
    },
    /// A zero-length prefix (end-of-stream marker) was read; any further data is an error
    Finished,
}
77
78impl Default for DecoderState {
79 fn default() -> Self {
80 Self::Header {
81 buf: [0; 4],
82 read: 0,
83 continuation: false,
84 }
85 }
86}
87
88impl StreamDecoder {
89 pub fn new() -> Self {
91 Self::default()
92 }
93
94 pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
107 self.require_alignment = require_alignment;
108 self
109 }
110
111 pub fn schema(&self) -> Option<SchemaRef> {
113 self.schema.as_ref().map(|schema| schema.clone())
114 }
115
116 pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
126 unsafe { self.skip_validation.set(skip_validation) };
127 self
128 }
129
130 pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
160 while !buffer.is_empty() {
161 match &mut self.state {
162 DecoderState::Header {
163 buf,
164 read,
165 continuation,
166 } => {
167 let offset_buf = &mut buf[*read as usize..];
168 let to_read = buffer.len().min(offset_buf.len());
169 offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
170 *read += to_read as u8;
171 buffer.advance(to_read);
172 if *read == 4 {
173 if !*continuation && buf == &CONTINUATION_MARKER {
174 *continuation = true;
175 *read = 0;
176 continue;
177 }
178 let size = u32::from_le_bytes(*buf);
179
180 if size == 0 {
181 self.state = DecoderState::Finished;
182 continue;
183 }
184 self.state = DecoderState::Message { size };
185 }
186 }
187 DecoderState::Message { size } => {
188 let len = *size as usize;
189 if self.buf.is_empty() && buffer.len() > len {
190 let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
191 self.state = DecoderState::Body { message };
192 buffer.advance(len);
193 continue;
194 }
195
196 let to_read = buffer.len().min(len - self.buf.len());
197 self.buf.extend_from_slice(&buffer[..to_read]);
198 buffer.advance(to_read);
199 if self.buf.len() == len {
200 let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
201 self.state = DecoderState::Body { message };
202 }
203 }
204 DecoderState::Body { message } => {
205 let message = message.as_ref();
206 let body_length = message.bodyLength() as usize;
207
208 let body = if self.buf.is_empty() && buffer.len() >= body_length {
209 let body = buffer.slice_with_length(0, body_length);
210 buffer.advance(body_length);
211 body
212 } else {
213 let to_read = buffer.len().min(body_length - self.buf.len());
214 self.buf.extend_from_slice(&buffer[..to_read]);
215 buffer.advance(to_read);
216
217 if self.buf.len() != body_length {
218 continue;
219 }
220 std::mem::take(&mut self.buf).into()
221 };
222
223 let version = message.version();
224 match message.header_type() {
225 MessageHeader::Schema => {
226 if self.schema.is_some() {
227 return Err(ArrowError::IpcError(
228 "Not expecting a schema when messages are read".to_string(),
229 ));
230 }
231
232 let ipc_schema = message.header_as_schema().unwrap();
233 let schema = crate::convert::fb_to_schema(ipc_schema);
234 self.state = DecoderState::default();
235 self.schema = Some(Arc::new(schema));
236 }
237 MessageHeader::RecordBatch => {
238 let batch = message.header_as_record_batch().unwrap();
239 let schema = self.schema.clone().ok_or_else(|| {
240 ArrowError::IpcError("Missing schema".to_string())
241 })?;
242 let batch = RecordBatchDecoder::try_new(
243 &body,
244 batch,
245 schema,
246 &self.dictionaries,
247 &version,
248 )?
249 .with_require_alignment(self.require_alignment)
250 .read_record_batch()?;
251 self.state = DecoderState::default();
252 return Ok(Some(batch));
253 }
254 MessageHeader::DictionaryBatch => {
255 let dictionary = message.header_as_dictionary_batch().unwrap();
256 let schema = self.schema.as_deref().ok_or_else(|| {
257 ArrowError::IpcError("Missing schema".to_string())
258 })?;
259 read_dictionary_impl(
260 &body,
261 dictionary,
262 schema,
263 &mut self.dictionaries,
264 &version,
265 self.require_alignment,
266 self.skip_validation.clone(),
267 )?;
268 self.state = DecoderState::default();
269 }
270 MessageHeader::NONE => {
271 self.state = DecoderState::default();
272 }
273 t => {
274 return Err(ArrowError::IpcError(format!(
275 "Message type unsupported by StreamDecoder: {t:?}"
276 )));
277 }
278 }
279 }
280 DecoderState::Finished => {
281 return Err(ArrowError::IpcError("Unexpected EOS".to_string()));
282 }
283 }
284 }
285 Ok(None)
286 }
287
288 pub fn finish(&mut self) -> Result<(), ArrowError> {
292 match self.state {
293 DecoderState::Finished
294 | DecoderState::Header {
295 read: 0,
296 continuation: false,
297 ..
298 } => Ok(()),
299 _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray, types::Int32Type,
    };
    use arrow_schema::{DataType, Field, Schema};

    /// A stream truncated inside the end-of-stream marker still yields its record batch,
    /// but `finish` reports the incomplete stream.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        // The truncated end-of-stream marker is left for the next call
        assert_eq!(b.len(), 7);
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A stream containing only a schema exposes that schema via `schema()`.
    #[test]
    fn test_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap();
        assert!(output.is_none());
        let decoded_schema = decoder.schema().unwrap();
        assert_eq!(schema, decoded_schema);

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// Run-end encoded columns with dictionary values round-trip through the decoder.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                IpcWriteOptions::default(),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        let mut num_batches = 0;
        // Use a distinct name so the decoded batch is compared against the input, rather than
        // the shadowed loop variable being compared with itself
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {e}").into())
            })
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            num_batches += 1;
        }
        assert_eq!(num_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}