1use std::io::{Read, Write};
23use std::sync::Arc;
24
25use arrow_array::RecordBatch;
26use arrow_buffer::Buffer;
27use arrow_ipc::convert::fb_to_schema;
28use arrow_ipc::reader::FileDecoder;
29use arrow_ipc::root_as_message;
30use arrow_ipc::writer::StreamWriter;
31use arrow_schema::ArrowError;
32use bytes::Bytes;
33
34pub fn write_len_prefixed_bytes(writer: &mut dyn Write, data: &[u8]) -> Result<(), ArrowError> {
42 writer
43 .write_all(&(data.len() as u64).to_le_bytes())
44 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
45 writer
46 .write_all(data)
47 .map_err(|e| ArrowError::IoError(e.to_string(), e))
48}
49
50pub fn read_len_prefixed_bytes(reader: &mut dyn Read) -> Result<Vec<u8>, ArrowError> {
54 let mut len_buf = [0u8; 8];
55 reader
56 .read_exact(&mut len_buf)
57 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
58 let len = u64::from_le_bytes(len_buf) as usize;
59 let mut buf = vec![0u8; len];
60 reader
61 .read_exact(&mut buf)
62 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
63 Ok(buf)
64}
65
/// Continuation marker (`0xFFFFFFFF`) that precedes the `u32` metadata length of an
/// encapsulated IPC message. When it is absent, the first four bytes are treated as
/// the legacy bare length.
const IPC_CONTINUATION: [u8; 4] = u32::MAX.to_le_bytes();
72
73pub fn write_ipc_stream(batch: &RecordBatch, writer: &mut dyn Write) -> Result<(), ArrowError> {
75 let mut sw = StreamWriter::try_new(&mut *writer, batch.schema_ref())?;
76 sw.write(batch)?;
77 sw.finish()
78}
79
80pub fn write_ipc_stream_batches<I>(iter: I, writer: &mut dyn Write) -> Result<(), ArrowError>
87where
88 I: IntoIterator<Item = RecordBatch>,
89{
90 let mut iter = iter.into_iter();
91 let first = iter
92 .next()
93 .ok_or_else(|| ArrowError::InvalidArgumentError("no batches to serialize".into()))?;
94 let mut sw = StreamWriter::try_new(&mut *writer, first.schema_ref())?;
95 sw.write(&first)?;
96 for batch in iter {
97 sw.write(&batch)?;
98 }
99 sw.finish()
100}
101
102fn read_one_ipc_message(data: &Bytes) -> Result<Option<Buffer>, ArrowError> {
110 let bytes = data.as_ref();
111
112 if bytes.is_empty() {
113 return Ok(None);
114 }
115 if bytes.len() < 4 {
116 return Err(ArrowError::IoError(
117 "IPC: truncated header".into(),
118 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC header"),
119 ));
120 }
121
122 let has_continuation = bytes[..4] == IPC_CONTINUATION;
123 let (size_bytes, prefix_len): ([u8; 4], usize) = if has_continuation {
124 if bytes.len() < 8 {
125 return Err(ArrowError::IoError(
126 "IPC: truncated header after continuation".into(),
127 std::io::Error::new(
128 std::io::ErrorKind::UnexpectedEof,
129 "truncated after continuation",
130 ),
131 ));
132 }
133 (bytes[4..8].try_into().unwrap(), 8)
134 } else {
135 (bytes[..4].try_into().unwrap(), 4)
136 };
137
138 let meta_size = u32::from_le_bytes(size_bytes) as usize;
139 if meta_size == 0 {
140 return Ok(None); }
142
143 let meta_end = prefix_len + meta_size;
144 if bytes.len() < meta_end {
145 return Err(ArrowError::IoError(
146 "IPC: truncated metadata".into(),
147 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC metadata"),
148 ));
149 }
150
151 let msg = root_as_message(&bytes[prefix_len..meta_end])
152 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
153 let body_len = msg.bodyLength() as usize;
154
155 let total = meta_end + body_len;
156 if bytes.len() < total {
157 return Err(ArrowError::IoError(
158 "IPC: truncated body".into(),
159 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated IPC body"),
160 ));
161 }
162
163 Ok(Some(Buffer::from(data.slice(0..total))))
166}
167
168pub fn read_len_prefixed_bytes_at(data: &Bytes, offset: &mut usize) -> Result<Bytes, ArrowError> {
173 let bytes = data.as_ref();
174 let len_end = offset
175 .checked_add(8)
176 .filter(|&e| e <= bytes.len())
177 .ok_or_else(|| {
178 ArrowError::IoError(
179 "length-prefixed bytes: truncated length field".into(),
180 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated length"),
181 )
182 })?;
183 let len = u64::from_le_bytes(bytes[*offset..len_end].try_into().unwrap()) as usize;
184 *offset = len_end;
185 let data_end = offset
186 .checked_add(len)
187 .filter(|&e| e <= bytes.len())
188 .ok_or_else(|| {
189 ArrowError::IoError(
190 "length-prefixed bytes: truncated data".into(),
191 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "truncated data"),
192 )
193 })?;
194 let result = data.slice(*offset..data_end);
195 *offset = data_end;
196 Ok(result)
197}
198
199pub fn read_ipc_stream_at(
204 data: &Bytes,
205 offset: &mut usize,
206) -> Result<Vec<RecordBatch>, ArrowError> {
207 let batches = read_ipc_stream(&data.slice(*offset..))?;
208
209 let slice = &data.as_ref()[*offset..];
213 let mut consumed = 0usize;
214 loop {
215 let rem = &slice[consumed..];
216 if rem.is_empty() {
217 break;
218 }
219 let has_cont = rem.len() >= 4 && rem[..4] == IPC_CONTINUATION;
220 let (size_bytes, prefix_len): ([u8; 4], usize) = if has_cont {
221 if rem.len() < 8 {
222 break;
223 }
224 (rem[4..8].try_into().unwrap(), 8)
225 } else {
226 if rem.len() < 4 {
227 break;
228 }
229 (rem[..4].try_into().unwrap(), 4)
230 };
231 let meta_size = u32::from_le_bytes(size_bytes) as usize;
232 if meta_size == 0 {
233 consumed += prefix_len;
235 break;
236 }
237 let meta_end = prefix_len + meta_size;
238 if rem.len() < meta_end {
239 break;
240 }
241 let msg = root_as_message(&rem[prefix_len..meta_end])
242 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
243 let body_len = msg.bodyLength() as usize;
244 consumed += meta_end + body_len;
245 }
246 *offset += consumed;
247
248 Ok(batches)
249}
250
251pub fn read_ipc_stream_single_at(
256 data: &Bytes,
257 offset: &mut usize,
258) -> Result<RecordBatch, ArrowError> {
259 let mut batches = read_ipc_stream_at(data, offset)?;
260 match batches.len() {
261 1 => Ok(batches.remove(0)),
262 n => Err(ArrowError::ParseError(format!(
263 "expected exactly 1 IPC record batch, got {n}"
264 ))),
265 }
266}
267
268fn parse_ipc_message_prefix(buf: &Buffer) -> Result<(usize, usize), ArrowError> {
273 let has_continuation = buf.len() >= 4 && buf[..4] == IPC_CONTINUATION;
274 if has_continuation {
275 if buf.len() < 8 {
276 return Err(ArrowError::ParseError(
277 "IPC message buffer too short".into(),
278 ));
279 }
280 let meta_size = u32::from_le_bytes(buf[4..8].try_into().unwrap()) as usize;
281 Ok((8, meta_size))
282 } else {
283 if buf.len() < 4 {
284 return Err(ArrowError::ParseError(
285 "IPC message buffer too short".into(),
286 ));
287 }
288 let meta_size = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
289 Ok((4, meta_size))
290 }
291}
292
293pub fn read_ipc_stream(data: &Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
302 let mut offset = 0usize;
303
304 let schema_buf = read_one_ipc_message(&data.slice(offset..))?.ok_or_else(|| {
305 ArrowError::ParseError("IPC stream: expected schema message, got EOS".into())
306 })?;
307 offset += schema_buf.len();
308
309 let (prefix_len, meta_size) = parse_ipc_message_prefix(&schema_buf)?;
310 let schema_msg = root_as_message(&schema_buf[prefix_len..prefix_len + meta_size])
311 .map_err(|e| ArrowError::ParseError(format!("IPC schema parse error: {e}")))?;
312 let schema = Arc::new(fb_to_schema(schema_msg.header_as_schema().ok_or_else(
313 || ArrowError::ParseError("IPC stream: first message is not a schema".into()),
314 )?));
315 let mut decoder = FileDecoder::new(schema, schema_msg.version());
316
317 let mut batches = Vec::new();
318
319 loop {
320 let Some(buf) = read_one_ipc_message(&data.slice(offset..))? else {
321 break;
322 };
323 offset += buf.len();
324
325 let (prefix_len, meta_size) = parse_ipc_message_prefix(&buf)?;
326 let msg = root_as_message(&buf[prefix_len..prefix_len + meta_size])
327 .map_err(|e| ArrowError::ParseError(format!("IPC message parse error: {e}")))?;
328 let body_len = msg.bodyLength() as usize;
329
330 let block = arrow_ipc::Block::new(0, (prefix_len + meta_size) as i32, body_len as i64);
333
334 match msg.header_type() {
335 arrow_ipc::MessageHeader::RecordBatch => {
336 if let Some(batch) = decoder.read_record_batch(&block, &buf)? {
337 batches.push(batch);
338 }
339 }
340 arrow_ipc::MessageHeader::DictionaryBatch => {
341 decoder.read_dictionary(&block, &buf)?;
342 }
343 _ => break,
344 }
345 }
346
347 Ok(batches)
348}
349
350pub fn read_ipc_stream_single(data: &Bytes) -> Result<RecordBatch, ArrowError> {
352 let mut batches = read_ipc_stream(data)?;
353 match batches.len() {
354 1 => Ok(batches.remove(0)),
355 n => Err(ArrowError::ParseError(format!(
356 "expected exactly 1 IPC record batch, got {n}"
357 ))),
358 }
359}
360
/// Byte alignment (64) of the start of every IPC section.
///
/// Writers pad up to the next multiple of this value before each section, and
/// readers skip the same amount.
pub const IPC_SECTION_ALIGNMENT: usize = 1 << 6;
374
375fn section_padding(pos: usize) -> usize {
378 (IPC_SECTION_ALIGNMENT - (pos % IPC_SECTION_ALIGNMENT)) % IPC_SECTION_ALIGNMENT
379}
380
/// `Write` adapter that forwards everything to `inner` and keeps a running total of
/// the bytes `inner` reported as written.
struct CountingWriter<'a> {
    /// Destination that receives every byte.
    inner: &'a mut dyn Write,
    /// Bytes accepted by `inner` across all successful `write` calls.
    count: usize,
}

impl<'a> Write for CountingWriter<'a> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Count only what `inner` accepted; a failed write leaves the tally untouched.
        self.inner.write(buf).map(|written| {
            self.count += written;
            written
        })
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
397}
398
399fn write_section_padding(writer: &mut dyn Write, pos: &mut usize) -> Result<(), ArrowError> {
402 let pad = section_padding(*pos);
403 if pad > 0 {
404 const ZEROS: [u8; IPC_SECTION_ALIGNMENT] = [0u8; IPC_SECTION_ALIGNMENT];
405 writer
406 .write_all(&ZEROS[..pad])
407 .map_err(|e| ArrowError::IoError(e.to_string(), e))?;
408 *pos += pad;
409 }
410 Ok(())
411}
412
413pub fn write_ipc_section(
425 writer: &mut dyn Write,
426 pos: &mut usize,
427 batch: &RecordBatch,
428) -> Result<(), ArrowError> {
429 write_section_padding(writer, pos)?;
430
431 let mut counting = CountingWriter {
432 inner: writer,
433 count: 0,
434 };
435 write_ipc_stream(batch, &mut counting)?;
436 *pos += counting.count;
437 Ok(())
438}
439
440pub fn read_ipc_section_at(data: &Bytes, offset: &mut usize) -> Result<RecordBatch, ArrowError> {
448 *offset += section_padding(*offset);
449 read_ipc_stream_single_at(data, offset)
450}
451
452pub fn write_ipc_section_batches<I>(
458 writer: &mut dyn Write,
459 pos: &mut usize,
460 iter: I,
461) -> Result<(), ArrowError>
462where
463 I: IntoIterator<Item = RecordBatch>,
464{
465 write_section_padding(writer, pos)?;
466
467 let mut counting = CountingWriter {
468 inner: writer,
469 count: 0,
470 };
471 write_ipc_stream_batches(iter, &mut counting)?;
472 *pos += counting.count;
473 Ok(())
474}
475
476pub fn read_ipc_section_batches_at(
482 data: &Bytes,
483 offset: &mut usize,
484) -> Result<Vec<RecordBatch>, ArrowError> {
485 *offset += section_padding(*offset);
486 read_ipc_stream_at(data, offset)
487}
488
#[cfg(test)]
mod tests {
    use arrow_array::{ArrayRef, record_batch};

    use super::*;

    #[test]
    fn test_ipc_roundtrip() {
        let batch1 = record_batch!(
            ("int", Int32, [1, 2, 3]),
            ("str", Utf8, ["foo", "bar", "baz"])
        )
        .unwrap();
        let batch2 = record_batch!(("int", Int32, [4, 5]), ("str", Utf8, ["qux", "quux"])).unwrap();
        let batches = vec![batch1.clone(), batch2.clone()];

        let mut buf: Vec<u8> = Vec::new();
        write_ipc_stream_batches(batches, &mut buf).unwrap();

        // Use an aligned copy so the zero-copy assertions below don't depend on how
        // the global allocator happens to align a plain `Vec<u8>`.
        let data = aligned_bytes(&buf);

        let batches = read_ipc_stream(&data).unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0], batch1);
        assert_eq!(batches[1], batch2);

        let data_base = data.as_ptr() as usize;
        let data_end = data_base + data.len();
        let assert_col_zero_copy = |array: &ArrayRef| {
            for buffer in array.to_data().buffers() {
                let ptr = buffer.as_ptr() as usize;
                assert!(
                    ptr >= data_base && ptr < data_end,
                    "buffer at {ptr:#x} is not backed by the input Bytes allocation \
                     [{data_base:#x}..{data_end:#x})"
                );
            }
        };

        for batch in &batches {
            assert_eq!(batch.schema(), batch1.schema());
            assert_col_zero_copy(batch.column(0));
            assert_col_zero_copy(batch.column(1));
        }
    }

    /// Copies `payload` into a fresh allocation whose first byte sits on an
    /// `IPC_SECTION_ALIGNMENT` boundary, and returns a `Bytes` view of exactly
    /// `payload.len()` bytes.
    fn aligned_bytes(payload: &[u8]) -> Bytes {
        let mut v = vec![0u8; payload.len() + IPC_SECTION_ALIGNMENT];
        let pad = section_padding(v.as_ptr() as usize);
        v[pad..pad + payload.len()].copy_from_slice(payload);
        Bytes::from(v).slice(pad..pad + payload.len())
    }

    #[test]
    fn test_aligned_ipc_sections_are_zero_copy() {
        let blocks = arrow_array::LargeBinaryArray::from_vec(vec![&b"hello"[..], b"world"]);
        let section_a = RecordBatch::try_from_iter([("a", Arc::new(blocks) as ArrayRef)]).unwrap();
        let section_b = record_batch!(("b", Int64, [10i64, 20, 30, 40, 50])).unwrap();

        // Start at an unaligned position so the first section needs padding.
        let mut buf = Vec::new();
        buf.extend_from_slice(&[0xABu8; 7]);
        let mut pos = buf.len();
        assert_eq!(7 + section_padding(7), IPC_SECTION_ALIGNMENT);
        write_ipc_section(&mut buf, &mut pos, &section_a).unwrap();
        write_ipc_section(&mut buf, &mut pos, &section_b).unwrap();
        assert_eq!(pos, buf.len(), "pos should track every byte written");

        let data = aligned_bytes(&buf);
        assert_eq!(
            section_padding(data.as_ptr() as usize),
            0,
            "base not aligned"
        );

        let mut offset = 7;
        let read_a = read_ipc_section_at(&data, &mut offset).unwrap();
        let read_b = read_ipc_section_at(&data, &mut offset).unwrap();
        assert_eq!(read_a, section_a);
        assert_eq!(read_b, section_b);

        let data_base = data.as_ptr() as usize;
        let data_end = data_base + data.len();
        for batch in [&read_a, &read_b] {
            for buffer in batch.column(0).to_data().buffers() {
                let ptr = buffer.as_ptr() as usize;
                assert!(
                    ptr >= data_base && ptr < data_end,
                    "section buffer at {ptr:#x} was realigned out of the input \
                     [{data_base:#x}..{data_end:#x}) — misaligned section",
                );
            }
        }
    }

    #[test]
    fn test_aligned_multi_batch_section_roundtrip_zero_copy() {
        let b1 = record_batch!(("v", Int64, [1i64, 2, 3])).unwrap();
        let b2 = record_batch!(("v", Int64, [4i64, 5])).unwrap();
        let b3 = record_batch!(("v", Int64, [6i64])).unwrap();

        let mut buf = vec![0xCDu8; 5];
        let mut pos = buf.len();
        write_ipc_section_batches(&mut buf, &mut pos, [b1.clone(), b2.clone(), b3.clone()])
            .unwrap();

        let data = aligned_bytes(&buf);
        let mut offset = 5;
        let read = read_ipc_section_batches_at(&data, &mut offset).unwrap();
        assert_eq!(read, vec![b1, b2, b3]);
        assert_eq!(offset, buf.len(), "offset should land at section end");

        let data_base = data.as_ptr() as usize;
        let data_end = data_base + data.len();
        for buffer in read[0].column(0).to_data().buffers() {
            let ptr = buffer.as_ptr() as usize;
            assert!(
                ptr >= data_base && ptr < data_end,
                "first batch buffer at {ptr:#x} was realigned out of the input",
            );
        }
    }
}