use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::UnsafeFlag;
use arrow_schema::{ArrowError, SchemaRef};
use crate::convert::MessageBuffer;
use crate::reader::{RecordBatchDecoder, read_dictionary_impl};
use crate::{CONTINUATION_MARKER, MessageHeader};
/// Incremental decoder for the Arrow IPC stream format: feed it byte
/// [`Buffer`]s via `decode` and it yields [`RecordBatch`]es as complete
/// messages are assembled.
#[derive(Debug, Default)]
pub struct StreamDecoder {
    // Schema decoded from the stream's schema message, if one has been seen
    schema: Option<SchemaRef>,
    // Dictionaries decoded so far, keyed by dictionary id
    dictionaries: HashMap<i64, ArrayRef>,
    // Current phase of the message-parsing state machine
    state: DecoderState,
    // Scratch buffer used to accumulate a message (or body) that arrives
    // split across multiple `decode` calls; empty between messages
    buf: MutableBuffer,
    // Forwarded to the record-batch / dictionary readers: require buffers
    // to be properly aligned instead of copying them
    require_alignment: bool,
    // When set, validation is skipped when reading dictionaries
    // (passed through to `read_dictionary_impl`)
    skip_validation: UnsafeFlag,
}
/// Phases of the IPC stream parsing state machine.
#[derive(Debug)]
enum DecoderState {
    /// Reading the 4-byte length prefix (and an optional leading
    /// continuation marker) of the next message
    Header {
        // The prefix bytes read so far
        buf: [u8; 4],
        // How many bytes of `buf` are valid (0..=4)
        read: u8,
        // Whether a `CONTINUATION_MARKER` has already been consumed
        // for this header
        continuation: bool,
    },
    /// Reading the flatbuffer message metadata
    Message {
        // Length of the message in bytes, taken from the prefix
        size: u32,
    },
    /// Reading the message body that follows the flatbuffer metadata
    Body {
        // The parsed flatbuffer message header
        message: MessageBuffer,
    },
    /// A zero-length prefix (end-of-stream marker) has been seen
    Finished,
}
impl Default for DecoderState {
    /// A fresh decoder starts out waiting for a 4-byte length prefix,
    /// with no bytes read and no continuation marker seen.
    fn default() -> Self {
        let empty_prefix = [0u8; 4];
        Self::Header {
            buf: empty_prefix,
            read: 0,
            continuation: false,
        }
    }
}
impl StreamDecoder {
pub fn new() -> Self {
Self::default()
}
pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
self.require_alignment = require_alignment;
self
}
pub fn schema(&self) -> Option<SchemaRef> {
self.schema.as_ref().map(|schema| schema.clone())
}
pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
while !buffer.is_empty() {
match &mut self.state {
DecoderState::Header {
buf,
read,
continuation,
} => {
let offset_buf = &mut buf[*read as usize..];
let to_read = buffer.len().min(offset_buf.len());
offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
*read += to_read as u8;
buffer.advance(to_read);
if *read == 4 {
if !*continuation && buf == &CONTINUATION_MARKER {
*continuation = true;
*read = 0;
continue;
}
let size = u32::from_le_bytes(*buf);
if size == 0 {
self.state = DecoderState::Finished;
continue;
}
self.state = DecoderState::Message { size };
}
}
DecoderState::Message { size } => {
let len = *size as usize;
if self.buf.is_empty() && buffer.len() > len {
let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
self.state = DecoderState::Body { message };
buffer.advance(len);
continue;
}
let to_read = buffer.len().min(len - self.buf.len());
self.buf.extend_from_slice(&buffer[..to_read]);
buffer.advance(to_read);
if self.buf.len() == len {
let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
self.state = DecoderState::Body { message };
}
}
DecoderState::Body { message } => {
let message = message.as_ref();
let body_length = message.bodyLength() as usize;
let body = if self.buf.is_empty() && buffer.len() >= body_length {
let body = buffer.slice_with_length(0, body_length);
buffer.advance(body_length);
body
} else {
let to_read = buffer.len().min(body_length - self.buf.len());
self.buf.extend_from_slice(&buffer[..to_read]);
buffer.advance(to_read);
if self.buf.len() != body_length {
continue;
}
std::mem::take(&mut self.buf).into()
};
let version = message.version();
match message.header_type() {
MessageHeader::Schema => {
if self.schema.is_some() {
return Err(ArrowError::IpcError(
"Not expecting a schema when messages are read".to_string(),
));
}
let ipc_schema = message.header_as_schema().unwrap();
let schema = crate::convert::fb_to_schema(ipc_schema);
self.state = DecoderState::default();
self.schema = Some(Arc::new(schema));
}
MessageHeader::RecordBatch => {
let batch = message.header_as_record_batch().unwrap();
let schema = self.schema.clone().ok_or_else(|| {
ArrowError::IpcError("Missing schema".to_string())
})?;
let batch = RecordBatchDecoder::try_new(
&body,
batch,
schema,
&self.dictionaries,
&version,
)?
.with_require_alignment(self.require_alignment)
.read_record_batch()?;
self.state = DecoderState::default();
return Ok(Some(batch));
}
MessageHeader::DictionaryBatch => {
let dictionary = message.header_as_dictionary_batch().unwrap();
let schema = self.schema.as_deref().ok_or_else(|| {
ArrowError::IpcError("Missing schema".to_string())
})?;
read_dictionary_impl(
&body,
dictionary,
schema,
&mut self.dictionaries,
&version,
self.require_alignment,
self.skip_validation.clone(),
)?;
self.state = DecoderState::default();
}
MessageHeader::NONE => {
self.state = DecoderState::default();
}
t => {
return Err(ArrowError::IpcError(format!(
"Message type unsupported by StreamDecoder: {t:?}"
)));
}
}
}
DecoderState::Finished => {
return Err(ArrowError::IpcError("Unexpected EOS".to_string()));
}
}
}
Ok(None)
}
pub fn finish(&mut self) -> Result<(), ArrowError> {
match self.state {
DecoderState::Finished
| DecoderState::Header {
read: 0,
continuation: false,
..
} => Ok(()),
_ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray, types::Int32Type,
    };
    use arrow_schema::{DataType, Field, Schema};

    /// Truncating the final byte of the EOS marker must still decode the
    /// batch, but `finish` must report an unexpected end of stream.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));
        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();
        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        // Drop the last byte so the EOS marker is incomplete
        let buffer = Buffer::from_vec(buf);
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);

        // 7 bytes of the truncated EOS marker remain unconsumed
        assert_eq!(b.len(), 7);
        assert!(decoder.decode(&mut b).unwrap().is_none());
        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A schema-only stream exposes the schema but, when truncated,
    /// `finish` still errors.
    #[test]
    fn test_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));
        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap();
        assert!(output.is_none());
        let decoded_schema = decoder.schema().unwrap();
        assert_eq!(schema, decoded_schema);
        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// Round-trip a run-end-encoded column with a dictionary values child
    /// through the stream writer and the [`StreamDecoder`].
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                IpcWriteOptions::default(),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {e}").into())
            })
            .expect("Failed to decode record batch")
        {
            // Compare the round-tripped batch against the original input.
            // (Previously this compared the decoded batch with itself,
            // which asserts nothing.)
            assert_eq!(decoded, batch);
        }
        decoder.finish().expect("Failed to finish decoder");
    }
}