1use std::io::{Read, Write};
23use std::sync::Arc;
24
25use arrow_array::RecordBatch;
26use arrow_buffer::Buffer;
27use arrow_ipc::convert::fb_to_schema;
28use arrow_ipc::reader::FileDecoder;
29use arrow_ipc::root_as_message;
30use arrow_ipc::writer::StreamWriter;
31use arrow_schema::ArrowError;
32use bytes::Bytes;
33
34pub fn write_len_prefixed_bytes(writer: &mut dyn Write, data: &[u8]) -> Result<(), ArrowError> {
42 writer
43 .write_all(&(data.len() as u64).to_le_bytes())
44 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
45 writer
46 .write_all(data)
47 .map_err(|e| ArrowError::IoError(e.to_string(), e))
48}
49
50pub fn read_len_prefixed_bytes(reader: &mut dyn Read) -> Result<Vec<u8>, ArrowError> {
54 let mut len_buf = [0u8; 8];
55 reader
56 .read_exact(&mut len_buf)
57 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
58 let len = u64::from_le_bytes(len_buf) as usize;
59 let mut buf = vec![0u8; len];
60 reader
61 .read_exact(&mut buf)
62 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
63 Ok(buf)
64}
65
/// Continuation marker (`0xFFFFFFFF`) that precedes the 4-byte metadata length of an
/// encapsulated Arrow IPC message. Messages without it use the legacy framing, in
/// which the 4-byte length comes first.
const IPC_CONTINUATION: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF];
72
73pub fn write_ipc_stream(batch: &RecordBatch, writer: &mut dyn Write) -> Result<(), ArrowError> {
75 let mut sw = StreamWriter::try_new(&mut *writer, batch.schema_ref())?;
76 sw.write(batch)?;
77 sw.finish()
78}
79
80pub fn write_ipc_stream_batches<I>(iter: I, writer: &mut dyn Write) -> Result<(), ArrowError>
87where
88 I: IntoIterator<Item = RecordBatch>,
89{
90 let mut iter = iter.into_iter();
91 let first = iter
92 .next()
93 .ok_or_else(|| ArrowError::InvalidArgumentError("no batches to serialize".into()))?;
94 let mut sw = StreamWriter::try_new(&mut *writer, first.schema_ref())?;
95 sw.write(&first)?;
96 for batch in iter {
97 sw.write(&batch)?;
98 }
99 sw.finish()
100}
101
102fn read_one_ipc_message(data: &Bytes) -> Result<Option<Buffer>, ArrowError> {
110 let bytes = data.as_ref();
111
112 if bytes.is_empty() {
113 return Ok(None);
114 }
115 if bytes.len() < 4 {
116 return Err(ArrowError::IoError(
117 "IPC: truncated header".into(),
118 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC header"),
119 ));
120 }
121
122 let has_continuation = bytes[..4] == IPC_CONTINUATION;
123 let (size_bytes, prefix_len): ([u8; 4], usize) = if has_continuation {
124 if bytes.len() < 8 {
125 return Err(ArrowError::IoError(
126 "IPC: truncated header after continuation".into(),
127 std::io::Error::new(
128 std::io::ErrorKind::UnexpectedEof,
129 "truncated after continuation",
130 ),
131 ));
132 }
133 (bytes[4..8].try_into().unwrap(), 8)
134 } else {
135 (bytes[..4].try_into().unwrap(), 4)
136 };
137
138 let meta_size = u32::from_le_bytes(size_bytes) as usize;
139 if meta_size == 0 {
140 return Ok(None); }
142
143 let meta_end = prefix_len + meta_size;
144 if bytes.len() < meta_end {
145 return Err(ArrowError::IoError(
146 "IPC: truncated metadata".into(),
147 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC metadata"),
148 ));
149 }
150
151 let msg = root_as_message(&bytes[prefix_len..meta_end])
152 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
153 let body_len = msg.bodyLength() as usize;
154
155 let total = meta_end + body_len;
156 if bytes.len() < total {
157 return Err(ArrowError::IoError(
158 "IPC: truncated body".into(),
159 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC body"),
160 ));
161 }
162
163 Ok(Some(Buffer::from(data.slice(0..total))))
166}
167
168pub fn read_len_prefixed_bytes_at(data: &Bytes, offset: &mut usize) -> Result<Bytes, ArrowError> {
173 let bytes = data.as_ref();
174 let len_end = offset
175 .checked_add(8)
176 .filter(|&e| e <= bytes.len())
177 .ok_or_else(|| {
178 ArrowError::IoError(
179 "length-prefixed bytes: truncated length field".into(),
180 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated length"),
181 )
182 })?;
183 let len = u64::from_le_bytes(bytes[*offset..len_end].try_into().unwrap()) as usize;
184 *offset = len_end;
185 let data_end = offset
186 .checked_add(len)
187 .filter(|&e| e <= bytes.len())
188 .ok_or_else(|| {
189 ArrowError::IoError(
190 "length-prefixed bytes: truncated data".into(),
191 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated data"),
192 )
193 })?;
194 let result = data.slice(*offset..data_end);
195 *offset = data_end;
196 Ok(result)
197}
198
199pub fn read_ipc_stream_at(
204 data: &Bytes,
205 offset: &mut usize,
206) -> Result<Vec<RecordBatch>, ArrowError> {
207 let batches = read_ipc_stream(&data.slice(*offset..))?;
208
209 let slice = &data.as_ref()[*offset..];
213 let mut consumed = 0usize;
214 loop {
215 let rem = &slice[consumed..];
216 if rem.is_empty() {
217 break;
218 }
219 let has_cont = rem.len() >= 4 && rem[..4] == IPC_CONTINUATION;
220 let (size_bytes, prefix_len): ([u8; 4], usize) = if has_cont {
221 if rem.len() < 8 {
222 break;
223 }
224 (rem[4..8].try_into().unwrap(), 8)
225 } else {
226 if rem.len() < 4 {
227 break;
228 }
229 (rem[..4].try_into().unwrap(), 4)
230 };
231 let meta_size = u32::from_le_bytes(size_bytes) as usize;
232 if meta_size == 0 {
233 consumed += prefix_len;
235 break;
236 }
237 let meta_end = prefix_len + meta_size;
238 if rem.len() < meta_end {
239 break;
240 }
241 let msg = root_as_message(&rem[prefix_len..meta_end])
242 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
243 let body_len = msg.bodyLength() as usize;
244 consumed += meta_end + body_len;
245 }
246 *offset += consumed;
247
248 Ok(batches)
249}
250
251pub fn read_ipc_stream_single_at(
256 data: &Bytes,
257 offset: &mut usize,
258) -> Result<RecordBatch, ArrowError> {
259 let mut batches = read_ipc_stream_at(data, offset)?;
260 match batches.len() {
261 1 => Ok(batches.remove(0)),
262 n => Err(ArrowError::ParseError(format!(
263 "expected exactly 1 IPC record batch, got {n}"
264 ))),
265 }
266}
267
268fn parse_ipc_message_prefix(buf: &Buffer) -> Result<(usize, usize), ArrowError> {
273 let has_continuation = buf.len() >= 4 && buf[..4] == [0xff; 4];
274 if has_continuation {
275 if buf.len() < 8 {
276 return Err(ArrowError::ParseError(
277 "IPC message buffer too short".into(),
278 ));
279 }
280 let meta_size = u32::from_le_bytes(buf[4..8].try_into().unwrap()) as usize;
281 Ok((8, meta_size))
282 } else {
283 if buf.len() < 4 {
284 return Err(ArrowError::ParseError(
285 "IPC message buffer too short".into(),
286 ));
287 }
288 let meta_size = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
289 Ok((4, meta_size))
290 }
291}
292
293pub fn read_ipc_stream(data: &Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
302 let mut offset = 0usize;
303
304 let schema_buf = read_one_ipc_message(&data.slice(offset..))?.ok_or_else(|| {
305 ArrowError::ParseError("IPC stream: expected schema message, got EOS".into())
306 })?;
307 offset += schema_buf.len();
308
309 let (prefix_len, meta_size) = parse_ipc_message_prefix(&schema_buf)?;
310 let schema_msg = root_as_message(&schema_buf[prefix_len..prefix_len + meta_size])
311 .map_err(|e| ArrowError::ParseError(format!("IPC schema parse error: {e}")))?;
312 let schema = Arc::new(fb_to_schema(schema_msg.header_as_schema().ok_or_else(
313 || ArrowError::ParseError("IPC stream: first message is not a schema".into()),
314 )?));
315 let mut decoder = FileDecoder::new(schema, schema_msg.version());
316
317 let mut batches = Vec::new();
318
319 loop {
320 let Some(buf) = read_one_ipc_message(&data.slice(offset..))? else {
321 break;
322 };
323 offset += buf.len();
324
325 let (prefix_len, meta_size) = parse_ipc_message_prefix(&buf)?;
326 let msg = root_as_message(&buf[prefix_len..prefix_len + meta_size])
327 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
328 let body_len = msg.bodyLength() as usize;
329
330 let block = arrow_ipc::Block::new(0, (prefix_len + meta_size) as i32, body_len as i64);
333
334 match msg.header_type() {
335 arrow_ipc::MessageHeader::RecordBatch => {
336 if let Some(batch) = decoder.read_record_batch(&block, &buf)? {
337 batches.push(batch);
338 }
339 }
340 arrow_ipc::MessageHeader::DictionaryBatch => {
341 decoder.read_dictionary(&block, &buf)?;
342 }
343 _ => break,
344 }
345 }
346
347 Ok(batches)
348}
349
350pub fn read_ipc_stream_single(data: &Bytes) -> Result<RecordBatch, ArrowError> {
352 let mut batches = read_ipc_stream(data)?;
353 match batches.len() {
354 1 => Ok(batches.remove(0)),
355 n => Err(ArrowError::ParseError(format!(
356 "expected exactly 1 IPC record batch, got {n}"
357 ))),
358 }
359}
360
#[cfg(test)]
mod tests {
    use arrow_array::{ArrayRef, record_batch};

    use super::*;

    #[test]
    fn test_ipc_roundtrip() {
        let batch1 = record_batch!(
            ("int", Int32, [1, 2, 3]),
            ("str", Utf8, ["foo", "bar", "baz"])
        )
        .unwrap();
        let batch2 = record_batch!(("int", Int32, [4, 5]), ("str", Utf8, ["qux", "quux"])).unwrap();
        let batches = vec![batch1.clone(), batch2.clone()];

        let mut buf = Vec::new();
        write_ipc_stream_batches(batches, &mut buf).unwrap();

        let data = Bytes::from(buf);

        let batches = read_ipc_stream(&data).unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0], batch1);
        assert_eq!(batches[1], batch2);

        // Every decoded column buffer must point into the input allocation, i.e. the
        // read path is zero-copy.
        let data_base = data.as_ptr() as usize;
        let data_end = data_base + data.len();
        let assert_col_zero_copy = |array: &ArrayRef| {
            for buffer in array.to_data().buffers() {
                let ptr = buffer.as_ptr() as usize;
                assert!(
                    ptr >= data_base && ptr < data_end,
                    "buffer at {ptr:#x} is not backed by the input Bytes allocation \
                     [{data_base:#x}..{data_end:#x})"
                );
            }
        };

        for batch in &batches {
            assert_eq!(batch.schema(), batch1.schema());
            assert_col_zero_copy(batch.column(0));
            assert_col_zero_copy(batch.column(1));
        }
    }

    #[test]
    fn test_len_prefixed_roundtrip() {
        let mut buf = Vec::new();
        write_len_prefixed_bytes(&mut buf, b"hello").unwrap();
        write_len_prefixed_bytes(&mut buf, b"").unwrap();

        // Streaming reader over the framed bytes.
        let mut reader: &[u8] = &buf;
        assert_eq!(read_len_prefixed_bytes(&mut reader).unwrap(), b"hello");
        assert!(read_len_prefixed_bytes(&mut reader).unwrap().is_empty());
        assert!(reader.is_empty());
        assert!(read_len_prefixed_bytes(&mut reader).is_err());

        // Offset-based reader over the same bytes.
        let data = Bytes::from(buf);
        let mut offset = 0;
        assert_eq!(
            &read_len_prefixed_bytes_at(&data, &mut offset).unwrap()[..],
            b"hello"
        );
        assert_eq!(offset, 8 + 5);
        assert!(read_len_prefixed_bytes_at(&data, &mut offset).unwrap().is_empty());
        assert_eq!(offset, data.len());
        assert!(read_len_prefixed_bytes_at(&data, &mut offset).is_err());
    }

    #[test]
    fn test_ipc_stream_at_advances_past_stream() {
        let batch = record_batch!(("int", Int32, [1, 2, 3])).unwrap();

        // An IPC stream followed by an unrelated length-prefixed field.
        let mut buf = Vec::new();
        write_ipc_stream(&batch, &mut buf).unwrap();
        write_len_prefixed_bytes(&mut buf, b"trailer").unwrap();
        let data = Bytes::from(buf);

        let mut offset = 0;
        assert_eq!(read_ipc_stream_single_at(&data, &mut offset).unwrap(), batch);
        assert_eq!(
            &read_len_prefixed_bytes_at(&data, &mut offset).unwrap()[..],
            b"trailer"
        );
        assert_eq!(offset, data.len());
    }
}