1use std::collections::HashMap;
30use std::io::{Read, Write};
31use std::sync::Arc;
32
33use arrow_array::RecordBatch;
34use arrow_buffer::Buffer as ArrowBuffer;
35use arrow_ipc::reader as ipc_reader;
36use arrow_ipc::writer::{write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
37use arrow_ipc::{convert as ipc_convert, root_as_message, MessageHeader};
38use arrow_schema::{Schema, SchemaRef};
39use flatbuffers::FlatBufferBuilder;
40
41use crate::errors::{Result, RpcError};
42
/// Per-message custom metadata: string keys mapped to string values. It is
/// attached to, and read back from, an IPC message's `custom_metadata` field.
pub type Metadata = HashMap<String, String>;

/// Look up `key` in `md` and borrow its value as a `&str`.
///
/// Returns `None` when the key is absent.
#[inline]
pub fn md_get<'a>(md: &'a Metadata, key: &str) -> Option<&'a str> {
    md.get(key).map(|value| value.as_str())
}
53
/// One mebibyte; the unit the size caps below are expressed in.
const MIB: usize = 1 << 20;

/// Largest schema frame (metadata length or body length) accepted when a
/// `StreamReader` is opened. The check happens before any allocation.
pub const MAX_IPC_SCHEMA_BYTES: usize = 16 * MIB;

/// Largest metadata length or `bodyLength` accepted for each message after the
/// schema. Larger frames are rejected before any allocation.
pub const MAX_IPC_MESSAGE_BYTES: usize = 256 * MIB;

/// Four `0xFF` bytes written in front of the length word of every framed IPC
/// message, including the end-of-stream marker.
const CONTINUATION_MARKER: [u8; 4] = [0xFF; 4];
75
/// Arrow IPC stream-format writer that can attach custom metadata to each
/// record-batch message.
///
/// `StreamWriter::new` writes the schema message immediately. Each `write`
/// call emits any dictionary messages produced by the encoder, followed by one
/// record-batch message. The end-of-stream marker is written by `finish`, or
/// by `Drop` if `finish` was never called.
pub struct StreamWriter<W: Write> {
    // Destination of the encoded IPC bytes.
    writer: W,
    // Schema the stream was opened with; handed back by `schema()`.
    schema: SchemaRef,
    // Write options used for every message (constructed with `Default`).
    opts: IpcWriteOptions,
    // Arrow encoder that turns the schema and batches into IPC messages.
    data_gen: IpcDataGenerator,
    // Dictionary state shared across `encode` calls; created with
    // `DictionaryTracker::new(false)`.
    dict_tracker: DictionaryTracker,
    // Set once the end-of-stream marker has been written; `write` then fails
    // and `finish` becomes a no-op.
    finished: bool,
}
92
93impl<W: Write> StreamWriter<W> {
94 pub fn new(mut writer: W, schema: &Schema) -> Result<Self> {
96 let opts = IpcWriteOptions::default();
97 let data_gen = IpcDataGenerator::default();
98 let mut dict_tracker = DictionaryTracker::new(false);
99 let encoded =
100 data_gen.schema_to_bytes_with_dictionary_tracker(schema, &mut dict_tracker, &opts);
101 write_message(&mut writer, encoded, &opts)?;
102 Ok(Self {
103 writer,
104 schema: Arc::new(schema.clone()),
105 opts,
106 data_gen,
107 dict_tracker,
108 finished: false,
109 })
110 }
111
112 pub fn write(&mut self, batch: &RecordBatch, metadata: Option<&Metadata>) -> Result<()> {
116 if self.finished {
117 return Err(RpcError::new("IOError", "writer already finished"));
118 }
119 let mut ctx = Default::default();
120 let (dicts, data) = self
121 .data_gen
122 .encode(batch, &mut self.dict_tracker, &self.opts, &mut ctx)
123 .map_err(RpcError::from)?;
124 for d in dicts {
125 write_message(&mut self.writer, d, &self.opts).map_err(RpcError::from)?;
126 }
127 if let Some(md) = metadata.filter(|m| !m.is_empty()) {
128 let new_msg = repack_record_batch_message_with_metadata(&data.ipc_message, md)?;
129 let encoded = arrow_ipc::writer::EncodedData {
130 ipc_message: new_msg,
131 arrow_data: data.arrow_data,
132 };
133 write_message(&mut self.writer, encoded, &self.opts).map_err(RpcError::from)?;
134 } else {
135 write_message(&mut self.writer, data, &self.opts).map_err(RpcError::from)?;
136 }
137 Ok(())
138 }
139
140 pub fn schema(&self) -> SchemaRef {
142 self.schema.clone()
143 }
144
145 pub fn finish(&mut self) -> Result<()> {
147 if self.finished {
148 return Ok(());
149 }
150 self.writer.write_all(&CONTINUATION_MARKER)?;
151 self.writer.write_all(&[0u8; 4])?;
152 self.writer.flush()?;
153 self.finished = true;
154 Ok(())
155 }
156
157 pub fn flush(&mut self) -> Result<()> {
159 self.writer.flush()?;
160 Ok(())
161 }
162
163 pub fn get_mut(&mut self) -> &mut W {
164 &mut self.writer
165 }
166}
167
168impl<W: Write> Drop for StreamWriter<W> {
169 fn drop(&mut self) {
170 let _ = self.finish();
171 }
172}
173
174fn repack_record_batch_message_with_metadata(
177 msg_bytes: &[u8],
178 metadata: &Metadata,
179) -> Result<Vec<u8>> {
180 use arrow_ipc::{
181 Buffer as FbBuffer, FieldNode, KeyValue, KeyValueArgs, MessageBuilder, RecordBatchBuilder,
182 };
183
184 let msg = root_as_message(msg_bytes)
185 .map_err(|e| RpcError::new("IPC", format!("parsing message: {e}")))?;
186 let version = msg.version();
187 let header_type = msg.header_type();
188 let body_length = msg.bodyLength();
189 if header_type != MessageHeader::RecordBatch {
190 return Err(RpcError::new(
191 "IPC",
192 format!("repack expected RecordBatch header, got {header_type:?}"),
193 ));
194 }
195 let rb = msg
196 .header_as_record_batch()
197 .ok_or_else(|| RpcError::new("IPC", "missing RecordBatch header"))?;
198
199 let mut fbb = FlatBufferBuilder::new();
200
201 let src_nodes = rb
202 .nodes()
203 .ok_or_else(|| RpcError::new("IPC", "RecordBatch missing nodes"))?;
204 let nodes: Vec<FieldNode> = src_nodes.iter().copied().collect();
205 let nodes_vec = fbb.create_vector(&nodes);
206
207 let src_buffers = rb
208 .buffers()
209 .ok_or_else(|| RpcError::new("IPC", "RecordBatch missing buffers"))?;
210 let buffers: Vec<FbBuffer> = src_buffers.iter().copied().collect();
211 let buffers_vec = fbb.create_vector(&buffers);
212
213 let variadic_vec = rb.variadicBufferCounts().map(|v| {
214 let counts: Vec<i64> = v.iter().collect();
215 fbb.create_vector(&counts)
216 });
217
218 let new_rb = {
219 let mut b = RecordBatchBuilder::new(&mut fbb);
220 b.add_length(rb.length());
221 b.add_nodes(nodes_vec);
222 b.add_buffers(buffers_vec);
223 if let Some(v) = variadic_vec {
224 b.add_variadicBufferCounts(v);
225 }
226 b.finish()
229 };
230
231 let kvs: Vec<_> = metadata
235 .iter()
236 .map(|(k, v)| {
237 let k_off = fbb.create_string(k);
238 let v_off = fbb.create_string(v);
239 KeyValue::create(
240 &mut fbb,
241 &KeyValueArgs {
242 key: Some(k_off),
243 value: Some(v_off),
244 },
245 )
246 })
247 .collect();
248 let md_vec = fbb.create_vector(&kvs);
249
250 let mut mb = MessageBuilder::new(&mut fbb);
251 mb.add_version(version);
252 mb.add_header_type(header_type);
253 mb.add_header(new_rb.as_union_value());
254 mb.add_bodyLength(body_length);
255 mb.add_custom_metadata(md_vec);
256 let m = mb.finish();
257 fbb.finish(m, None);
258 Ok(fbb.finished_data().to_vec())
259}
260
/// Arrow IPC stream-format reader that also returns each record-batch
/// message's custom metadata.
///
/// The schema message is read when the reader is constructed. `read_next`
/// then yields record batches, storing any dictionary batches it passes on the
/// way. Frame sizes are capped (`MAX_IPC_SCHEMA_BYTES` / `MAX_IPC_MESSAGE_BYTES`)
/// before allocation.
pub struct StreamReader<R: Read> {
    // Source of the IPC bytes.
    reader: R,
    // Schema decoded from the stream's first message.
    schema: SchemaRef,
    // Dictionaries accumulated by `ipc_reader::read_dictionary`, then passed to
    // `read_record_batch`. Presumably keyed by dictionary id — confirm against
    // the arrow-ipc docs.
    dictionaries: HashMap<i64, arrow_array::ArrayRef>,
    // Set once end of stream is reached; later `read_next` calls return `None`.
    finished: bool,
    // When set (by `relax_nullability`), record batches are decoded against
    // this all-nullable copy of `schema`.
    relaxed_schema: Option<SchemaRef>,
}
281
282impl<R: Read> StreamReader<R> {
283 pub fn new(mut reader: R) -> Result<Self> {
290 let msg = read_message_bytes(&mut reader, MAX_IPC_SCHEMA_BYTES)?
291 .ok_or_else(|| RpcError::new("IPC", "empty IPC stream (no schema)"))?;
292 let msg_fb = root_as_message(&msg.message_bytes)
293 .map_err(|e| RpcError::new("IPC", format!("parse schema message: {e}")))?;
294 if msg_fb.header_type() != MessageHeader::Schema {
295 return Err(RpcError::new(
296 "IPC",
297 format!("expected Schema, got {:?}", msg_fb.header_type()),
298 ));
299 }
300 let ipc_schema = msg_fb
301 .header_as_schema()
302 .ok_or_else(|| RpcError::new("IPC", "bad schema header"))?;
303 let schema = ipc_convert::fb_to_schema(ipc_schema);
304 Ok(Self {
305 reader,
306 schema: Arc::new(schema),
307 dictionaries: HashMap::new(),
308 finished: false,
309 relaxed_schema: None,
310 })
311 }
312
313 pub fn schema(&self) -> SchemaRef {
316 self.relaxed_schema
317 .clone()
318 .unwrap_or_else(|| self.schema.clone())
319 }
320
321 pub fn relax_nullability(mut self) -> Self {
327 self.relaxed_schema = Some(Arc::new(relax_schema_nullability(self.schema.as_ref())));
328 self
329 }
330
331 pub fn read_next(&mut self) -> Result<Option<(RecordBatch, Metadata)>> {
336 if self.finished {
337 return Ok(None);
338 }
339 loop {
340 let msg = match read_message_bytes(&mut self.reader, MAX_IPC_MESSAGE_BYTES)? {
341 Some(m) => m,
342 None => {
343 self.finished = true;
344 return Ok(None);
345 }
346 };
347 let msg_fb = root_as_message(&msg.message_bytes)
348 .map_err(|e| RpcError::new("IPC", format!("parse message: {e}")))?;
349 let version = msg_fb.version();
350 match msg_fb.header_type() {
351 MessageHeader::DictionaryBatch => {
352 let dict = msg_fb
353 .header_as_dictionary_batch()
354 .ok_or_else(|| RpcError::new("IPC", "bad dictionary header"))?;
355 let body_buf = ArrowBuffer::from_vec(msg.body);
356 if let Some(data) = dict.data() {
360 validate_record_batch_buffers(&data, body_buf.len())?;
361 }
362 decode_guard("dictionary batch", || {
366 ipc_reader::read_dictionary(
367 &body_buf,
368 dict,
369 self.schema.as_ref(),
370 &mut self.dictionaries,
371 &version,
372 )
373 })?
374 .map_err(RpcError::from)?;
375 }
376 MessageHeader::RecordBatch => {
377 let rb_fb = msg_fb
378 .header_as_record_batch()
379 .ok_or_else(|| RpcError::new("IPC", "bad record batch header"))?;
380 let body_buf = ArrowBuffer::from_vec(msg.body);
381 validate_record_batch_buffers(&rb_fb, body_buf.len())?;
382 let decode_schema = self
389 .relaxed_schema
390 .clone()
391 .unwrap_or_else(|| self.schema.clone());
392 let batch = decode_guard("record batch", || {
393 ipc_reader::read_record_batch(
394 &body_buf,
395 rb_fb,
396 decode_schema,
397 &self.dictionaries,
398 None,
399 &version,
400 )
401 })?
402 .map_err(RpcError::from)?;
403 let metadata = parse_custom_metadata(&msg_fb);
404 return Ok(Some((batch, metadata)));
405 }
406 MessageHeader::Schema => {
407 return Err(RpcError::new("IPC", "unexpected schema message mid-stream"));
408 }
409 MessageHeader::NONE => continue,
410 other => {
411 return Err(RpcError::new(
412 "IPC",
413 format!("unsupported message type {other:?}"),
414 ));
415 }
416 }
417 }
418 }
419
420 pub fn drain(&mut self) -> Result<()> {
422 while self.read_next()?.is_some() {}
423 Ok(())
424 }
425
426 pub fn get_mut(&mut self) -> &mut R {
427 &mut self.reader
428 }
429}
430
431fn parse_custom_metadata(msg: &arrow_ipc::Message) -> Metadata {
432 let mut out = Metadata::new();
433 if let Some(md) = msg.custom_metadata() {
434 for kv in md.iter() {
435 let k = kv.key().unwrap_or("").to_string();
436 let v = kv.value().unwrap_or("").to_string();
437 out.insert(k, v);
438 }
439 }
440 out
441}
442
443fn validate_record_batch_buffers(rb: &arrow_ipc::RecordBatch, body_len: usize) -> Result<()> {
451 if let Some(buffers) = rb.buffers() {
452 for buf in buffers.iter() {
453 let offset = buf.offset();
454 let length = buf.length();
455 if offset < 0 || length < 0 {
456 return Err(RpcError::new("IPC", "negative IPC buffer descriptor"));
457 }
458 let end = (offset as u64)
459 .checked_add(length as u64)
460 .ok_or_else(|| RpcError::new("IPC", "IPC buffer descriptor overflows"))?;
461 if end > body_len as u64 {
462 return Err(RpcError::new(
463 "IPC",
464 "IPC buffer descriptor exceeds message body",
465 ));
466 }
467 }
468 }
469 Ok(())
470}
471
472fn decode_guard<T>(what: &str, f: impl FnOnce() -> T) -> Result<T> {
477 std::panic::catch_unwind(std::panic::AssertUnwindSafe(f))
478 .map_err(|_| RpcError::new("IPC", format!("panic decoding {what} (malformed frame)")))
479}
480
/// One framed IPC message as read off the wire, before decoding.
struct RawMessage {
    // The message's flatbuffer metadata, exactly `size` bytes as given by the
    // length prefix.
    message_bytes: Vec<u8>,
    // The message body, `bodyLength` bytes as declared in the metadata.
    body: Vec<u8>,
}
485
486fn read_exact(r: &mut impl Read, buf: &mut [u8]) -> Result<bool> {
487 let mut read = 0;
488 while read < buf.len() {
489 match r.read(&mut buf[read..]) {
490 Ok(0) => {
491 if read == 0 {
492 return Ok(false);
493 }
494 return Err(RpcError::new("IOError", "unexpected EOF in IPC message"));
495 }
496 Ok(n) => read += n,
497 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
498 Err(e) => return Err(e.into()),
499 }
500 }
501 Ok(true)
502}
503
504fn read_message_bytes(r: &mut impl Read, max_bytes: usize) -> Result<Option<RawMessage>> {
508 let mut prefix = [0u8; 4];
509 if !read_exact(r, &mut prefix)? {
510 return Ok(None);
511 }
512 let size_bytes = if prefix == CONTINUATION_MARKER {
513 let mut sb = [0u8; 4];
514 if !read_exact(r, &mut sb)? {
515 return Ok(None);
516 }
517 sb
518 } else {
519 prefix
520 };
521 let size = u32::from_le_bytes(size_bytes) as usize;
522 if size == 0 {
523 return Ok(None);
525 }
526 if size > max_bytes {
527 return Err(RpcError::new(
528 "IPC",
529 format!(
530 "IPC message header length {size} bytes exceeds cap {max_bytes} — \
531 refusing to allocate before parsing"
532 ),
533 ));
534 }
535 let mut message_bytes = vec![0u8; size];
536 if !read_exact(r, &mut message_bytes)? {
537 return Err(RpcError::new("IOError", "unexpected EOF in message body"));
538 }
539 let msg = root_as_message(&message_bytes)
543 .map_err(|e| RpcError::new("IPC", format!("parse message header: {e}")))?;
544 let body_length_signed = msg.bodyLength();
545 if body_length_signed < 0 {
546 return Err(RpcError::new(
547 "IPC",
548 format!("IPC message has negative bodyLength ({body_length_signed})"),
549 ));
550 }
551 let body_length = body_length_signed as usize;
552 if body_length > max_bytes {
553 return Err(RpcError::new(
554 "IPC",
555 format!(
556 "IPC message bodyLength {body_length} bytes exceeds cap {max_bytes} — \
557 refusing to allocate before parsing"
558 ),
559 ));
560 }
561 let mut body = vec![0u8; body_length];
562 if body_length > 0 && !read_exact(r, &mut body)? {
563 return Err(RpcError::new("IOError", "unexpected EOF in message body"));
564 }
565 Ok(Some(RawMessage {
566 message_bytes,
567 body,
568 }))
569}
570
571pub fn write_one_batch(batch: &RecordBatch, metadata: Option<&Metadata>) -> Result<Vec<u8>> {
578 let schema = batch.schema();
579 let mut buf = Vec::new();
580 {
581 let mut w = StreamWriter::new(&mut buf, schema.as_ref())?;
582 w.write(batch, metadata)?;
583 w.finish()?;
584 }
585 Ok(buf)
586}
587
/// Lowercase hexadecimal encoding of `bytes`: two characters per byte, high
/// nibble first.
pub(crate) fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    hex.extend(
        bytes
            .iter()
            .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
            .map(char::from),
    );
    hex
}
599
600fn relax_field_nullability(f: &arrow_schema::Field) -> arrow_schema::Field {
601 use arrow_schema::DataType;
602 let dt = match f.data_type() {
603 DataType::List(inner) => DataType::List(Arc::new(relax_field_nullability(inner))),
604 DataType::LargeList(inner) => DataType::LargeList(Arc::new(relax_field_nullability(inner))),
605 DataType::FixedSizeList(inner, n) => {
606 DataType::FixedSizeList(Arc::new(relax_field_nullability(inner)), *n)
607 }
608 DataType::Struct(fields) => DataType::Struct(
609 fields
610 .iter()
611 .map(|child| Arc::new(relax_field_nullability(child)))
612 .collect(),
613 ),
614 other => other.clone(),
618 };
619 #[allow(deprecated)]
620 let new_field = if let DataType::Dictionary(_, _) = f.data_type() {
621 arrow_schema::Field::new_dict(
622 f.name(),
623 dt,
624 true,
625 f.dict_id().unwrap_or(0),
626 f.dict_is_ordered().unwrap_or(false),
627 )
628 } else {
629 arrow_schema::Field::new(f.name(), dt, true)
630 };
631 new_field.with_metadata(f.metadata().clone())
632}
633
634fn relax_schema_nullability(s: &Schema) -> Schema {
635 let new_fields: Vec<arrow_schema::Field> = s
636 .fields()
637 .iter()
638 .map(|f| relax_field_nullability(f))
639 .collect();
640 Schema::new_with_metadata(new_fields, s.metadata().clone())
641}
642
643pub fn empty_batch(schema: &Schema) -> Result<RecordBatch> {
645 use arrow_array::array::new_empty_array;
646 use arrow_array::RecordBatchOptions;
647 let cols: Vec<arrow_array::ArrayRef> = schema
648 .fields()
649 .iter()
650 .map(|f| new_empty_array(f.data_type()))
651 .collect();
652 RecordBatch::try_new_with_options(
653 Arc::new(schema.clone()),
654 cols,
655 &RecordBatchOptions::new().with_row_count(Some(0)),
656 )
657 .map_err(RpcError::from)
658}
659
// Round-trip and hostile-input tests for the IPC stream reader/writer.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field};

    // A batch written with custom metadata reads back with the same rows and
    // metadata, followed by a clean end of stream.
    #[test]
    fn roundtrip_with_metadata() {
        let schema = Schema::new(vec![
            Field::new("idx", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
                Arc::new(StringArray::from(vec!["a", "b", "c"])) as _,
            ],
        )
        .unwrap();

        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::new(&mut buf, &schema).unwrap();
            let mut md = Metadata::new();
            md.insert("vgi_rpc.method".into(), "echo_string".into());
            w.write(&batch, Some(&md)).unwrap();
            w.finish().unwrap();
        }

        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let (rb, md) = r.read_next().unwrap().expect("batch");
        assert_eq!(rb.num_rows(), 3);
        assert_eq!(md_get(&md, "vgi_rpc.method"), Some("echo_string"));
        assert!(r.read_next().unwrap().is_none());
    }

    // A zero-row batch over an empty schema still carries its metadata.
    #[test]
    fn zero_row_metadata_only() {
        let schema = Schema::empty();
        let batch = empty_batch(&schema).unwrap();

        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::new(&mut buf, &schema).unwrap();
            let mut md = Metadata::new();
            md.insert("vgi_rpc.log_level".into(), "INFO".into());
            w.write(&batch, Some(&md)).unwrap();
            w.finish().unwrap();
        }
        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let (rb, md) = r.read_next().unwrap().expect("batch");
        assert_eq!(rb.num_rows(), 0);
        assert_eq!(md_get(&md, "vgi_rpc.log_level"), Some("INFO"));
    }

    // Four bytes that are not the continuation marker are read as a legacy
    // length prefix: 0x2CF52C1A (about 754 MB), far above MAX_IPC_SCHEMA_BYTES.
    // The reader must reject it before allocating.
    #[test]
    fn rejects_oversize_schema_length_prefix() {
        let bomb: &[u8] = &[0x1A, 0x2C, 0xF5, 0x2C];
        let err = StreamReader::new(bomb).err().expect("must reject");
        assert!(
            err.message.contains("exceeds cap"),
            "unexpected error: {err:?}"
        );
    }

    // A well-formed RecordBatch header declaring bodyLength = cap + 1 must be
    // rejected from the header alone, before the body is allocated or read.
    #[test]
    fn rejects_oversize_message_bodylength() {
        use arrow_ipc::{Buffer as FbBuffer, FieldNode, MessageBuilder, RecordBatchBuilder};
        let schema = Schema::new(vec![Field::new("v", DataType::Int64, false)]);
        let mut buf: Vec<u8> = Vec::new();
        {
            let w = StreamWriter::new(&mut buf, &schema).unwrap();
            // Skip `Drop` so no end-of-stream marker follows the schema
            // message; the hand-built frame is appended instead.
            std::mem::forget(w);
        }
        let mut fbb = FlatBufferBuilder::new();
        // Child vectors are built before the tables that reference them.
        let nodes_vec = fbb.create_vector(&[FieldNode::new(0, 0)]);
        let buffers_vec = fbb.create_vector(&[FbBuffer::new(0, 0)]);
        let rb_off = {
            let mut b = RecordBatchBuilder::new(&mut fbb);
            b.add_length(0);
            b.add_nodes(nodes_vec);
            b.add_buffers(buffers_vec);
            b.finish()
        };
        let msg_off = {
            let mut mb = MessageBuilder::new(&mut fbb);
            mb.add_version(arrow_ipc::MetadataVersion::V5);
            mb.add_header_type(MessageHeader::RecordBatch);
            mb.add_header(rb_off.as_union_value());
            mb.add_bodyLength(MAX_IPC_MESSAGE_BYTES as i64 + 1);
            mb.finish()
        };
        fbb.finish(msg_off, None);
        let msg_bytes = fbb.finished_data();
        // Frame: continuation marker, little-endian metadata length, metadata.
        // No body bytes follow.
        buf.extend_from_slice(&CONTINUATION_MARKER);
        buf.extend_from_slice(&(msg_bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(msg_bytes);
        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let err = r.read_next().expect_err("must reject");
        assert!(
            err.message.contains("bodyLength") && err.message.contains("exceeds cap"),
            "unexpected error: {err:?}"
        );
    }

    // A buffer descriptor spanning 1000 bytes inside an 8-byte body must be
    // caught by `validate_record_batch_buffers`, before arrow slices the body.
    #[test]
    fn rejects_buffer_descriptor_past_body() {
        use arrow_ipc::{Buffer as FbBuffer, FieldNode, MessageBuilder, RecordBatchBuilder};
        let schema = Schema::new(vec![Field::new("v", DataType::Int64, false)]);
        let mut buf: Vec<u8> = Vec::new();
        {
            let w = StreamWriter::new(&mut buf, &schema).unwrap();
            // Skip `Drop` so no end-of-stream marker follows the schema message.
            std::mem::forget(w);
        }
        let mut fbb = FlatBufferBuilder::new();
        let nodes_vec = fbb.create_vector(&[FieldNode::new(1, 0)]);
        let buffers_vec = fbb.create_vector(&[FbBuffer::new(0, 1000)]);
        let rb_off = {
            let mut b = RecordBatchBuilder::new(&mut fbb);
            b.add_length(1);
            b.add_nodes(nodes_vec);
            b.add_buffers(buffers_vec);
            b.finish()
        };
        let msg_off = {
            let mut mb = MessageBuilder::new(&mut fbb);
            mb.add_version(arrow_ipc::MetadataVersion::V5);
            mb.add_header_type(MessageHeader::RecordBatch);
            mb.add_header(rb_off.as_union_value());
            mb.add_bodyLength(8);
            mb.finish()
        };
        fbb.finish(msg_off, None);
        let msg_bytes = fbb.finished_data().to_vec();
        buf.extend_from_slice(&CONTINUATION_MARKER);
        buf.extend_from_slice(&(msg_bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(&msg_bytes);
        // The declared 8-byte body, so the frame itself is complete.
        buf.extend_from_slice(&[0u8; 8]);
        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let err = r.read_next().expect_err("must reject");
        assert!(
            err.message.contains("buffer descriptor"),
            "unexpected error: {err:?}"
        );
    }
}