1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc;
4use arrow_format::ipc::KeyValue;
5use arrow_format::ipc::planus::Builder;
6use bytes::Bytes;
7use polars_error::{PolarsResult, polars_bail, polars_err};
8use polars_utils::compression::ZstdLevel;
9
10use super::super::IpcField;
11use super::write;
12use crate::array::*;
13use crate::datatypes::*;
14use crate::io::ipc::endianness::is_native_little_endian;
15use crate::io::ipc::read::Dictionaries;
16use crate::legacy::prelude::LargeListArray;
17use crate::match_integer_type;
18use crate::record_batch::RecordBatchT;
19use crate::types::Index;
20
/// Compression codec applied to the body buffers of an IPC message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 frame-format compression.
    LZ4,
    /// Zstandard compression at the given level.
    ZSTD(ZstdLevel),
}
29
/// Options controlling how IPC messages are serialized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Which codec (if any) to compress the body buffers with.
    /// `None` (the default) writes the buffers uncompressed.
    pub compression: Option<Compression>,
}
37
/// Recursively collects, depth-first, every dictionary in `array` that still
/// has to be encoded, pushing `(dictionary_id, dictionary_array)` pairs onto
/// `dicts_to_encode`.
///
/// Nested types (dictionary values, struct, list, union, map) are traversed
/// so inner dictionaries are discovered as well. `dictionary_tracker`
/// de-duplicates dictionaries that were already registered under the same id.
///
/// # Errors
/// - a dictionary field has no `dictionary_id`;
/// - the `IpcField` tree does not match the array's children;
/// - a dictionary would be replaced while the tracker forbids replacement
///   (see [`DictionaryTracker::insert`]).
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        // Flat types cannot contain a dictionary: nothing to collect.
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            // Only schedule the dictionary when the tracker has not already
            // seen identical values under this id.
            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            // The dictionary values may themselves contain dictionaries.
            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            dictionaries_to_encode(field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation: "The number of fields in a struct must equal the number of children in IpcField");
            }
            // Recurse pairwise into each (IPC field, child array).
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = field.fields.as_slice();
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            // A map's entries are stored as one nested (struct) field.
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}
144
145pub fn encode_dictionary(
151 dict_id: i64,
152 array: &dyn Array,
153 options: &WriteOptions,
154) -> PolarsResult<EncodedData> {
155 let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
156 panic!("Given array is not a DictionaryArray")
157 };
158
159 match_integer_type!(key_type, |$T| {
160 let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();
161
162 encode_dictionary_values(dict_id, array.values().as_ref(), options)
163 })
164}
165
166pub fn encode_new_dictionaries(
167 field: &IpcField,
168 array: &dyn Array,
169 options: &WriteOptions,
170 dictionary_tracker: &mut DictionaryTracker,
171 encoded_dictionaries: &mut Vec<EncodedData>,
172) -> PolarsResult<()> {
173 let mut dicts_to_encode = Vec::new();
174 dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
175 for (dict_id, dict_array) in dicts_to_encode {
176 encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);
177 }
178 Ok(())
179}
180
181pub fn encode_chunk(
182 chunk: &RecordBatchT<Box<dyn Array>>,
183 fields: &[IpcField],
184 dictionary_tracker: &mut DictionaryTracker,
185 options: &WriteOptions,
186) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
187 let mut encoded_message = EncodedData::default();
188 let encoded_dictionaries = encode_chunk_amortized(
189 chunk,
190 fields,
191 dictionary_tracker,
192 options,
193 &mut encoded_message,
194 )?;
195 Ok((encoded_dictionaries, encoded_message))
196}
197
198pub fn encode_chunk_amortized(
200 chunk: &RecordBatchT<Box<dyn Array>>,
201 fields: &[IpcField],
202 dictionary_tracker: &mut DictionaryTracker,
203 options: &WriteOptions,
204 encoded_message: &mut EncodedData,
205) -> PolarsResult<Vec<EncodedData>> {
206 let mut encoded_dictionaries = vec![];
207
208 for (field, array) in fields.iter().zip(chunk.as_ref()) {
209 encode_new_dictionaries(
210 field,
211 array.as_ref(),
212 options,
213 dictionary_tracker,
214 &mut encoded_dictionaries,
215 )?;
216 }
217 encode_record_batch(chunk, options, encoded_message);
218
219 Ok(encoded_dictionaries)
220}
221
222fn serialize_compression(
223 compression: Option<Compression>,
224) -> Option<Box<arrow_format::ipc::BodyCompression>> {
225 if let Some(compression) = compression {
226 let codec = match compression {
227 Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
228 Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,
229 };
230 Some(Box::new(arrow_format::ipc::BodyCompression {
231 codec,
232 method: arrow_format::ipc::BodyCompressionMethod::Buffer,
233 }))
234 } else {
235 None
236 }
237}
238
/// Recursively appends, for each view array nested in `array` (depth-first),
/// the number of variadic data buffers it carries.
///
/// The IPC `RecordBatch` message stores one such count per view-typed field
/// so readers know how many data buffers to consume.
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype().to_storage() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            // Each child field may itself contain view arrays.
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            // Restrict recursion to the value range actually referenced by
            // the offsets, so sliced-away data is not counted.
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        // Dictionary values are written in separate dictionary batches, so
        // their buffers are not counted as part of this record batch.
        ArrowDataType::Dictionary(_, _, _) => (),
        // NOTE(review): `List` (i32 offsets) and `Map` also fall through to
        // this arm, so view arrays nested inside them would not be counted —
        // presumably those nestings cannot occur here; confirm upstream.
        _ => (),
    }
}
277
278fn gc_bin_view<'a, T: ViewType + ?Sized>(
279 arr: &'a Box<dyn Array>,
280 concrete_arr: &'a BinaryViewArrayGeneric<T>,
281) -> Cow<'a, Box<dyn Array>> {
282 let bytes_len = concrete_arr.total_bytes_len();
283 let buffer_len = concrete_arr.total_buffer_len();
284 let extra_len = buffer_len.saturating_sub(bytes_len);
285 if extra_len < bytes_len.min(1024) {
286 Cow::Borrowed(arr)
288 } else {
289 Cow::Owned(concrete_arr.clone().gc().boxed())
291 }
292}
293
/// Encodes a single array into the flattened IPC body.
///
/// Appends the array's field nodes, buffer descriptors and raw bytes to
/// `nodes`, `buffers` and `arrow_data` respectively; `offset` is advanced by
/// the bytes written. Variadic buffer counts for view arrays are appended to
/// `variadic_buffer_counts`.
pub fn encode_array(
    array: &Box<dyn Array>,
    options: &WriteOptions,
    variadic_buffer_counts: &mut Vec<i64>,
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<ipc::FieldNode>,
    offset: &mut i64,
) {
    // View arrays may reference buffers with lots of dead space; compact
    // them first when that saves a significant amount (see `gc_bin_view`).
    let array = match array.dtype() {
        ArrowDataType::BinaryView => {
            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        ArrowDataType::Utf8View => {
            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        _ => Cow::Borrowed(array),
    };
    // Cow<Box<dyn Array>> -> &dyn Array.
    let array = array.as_ref().as_ref();

    set_variadic_buffer_counts(variadic_buffer_counts, array);

    write(
        array,
        buffers,
        arrow_data,
        nodes,
        offset,
        is_native_little_endian(),
        options.compression,
    )
}
329
330pub fn encode_record_batch(
333 chunk: &RecordBatchT<Box<dyn Array>>,
334 options: &WriteOptions,
335 encoded_message: &mut EncodedData,
336) {
337 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
338 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
339 encoded_message.arrow_data.clear();
340
341 let mut offset = 0;
342 let mut variadic_buffer_counts = vec![];
343 for array in chunk.arrays() {
344 encode_array(
345 array,
346 options,
347 &mut variadic_buffer_counts,
348 &mut buffers,
349 &mut encoded_message.arrow_data,
350 &mut nodes,
351 &mut offset,
352 );
353 }
354
355 commit_encoded_arrays(
356 chunk.len(),
357 options,
358 variadic_buffer_counts,
359 buffers,
360 nodes,
361 None,
362 encoded_message,
363 );
364}
365
366pub fn commit_encoded_arrays(
367 array_len: usize,
368 options: &WriteOptions,
369 variadic_buffer_counts: Vec<i64>,
370 buffers: Vec<ipc::Buffer>,
371 nodes: Vec<ipc::FieldNode>,
372 custom_metadata: Option<Vec<KeyValue>>,
373 encoded_message: &mut EncodedData,
374) {
375 let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
376 None
377 } else {
378 Some(variadic_buffer_counts)
379 };
380
381 let compression = serialize_compression(options.compression);
382
383 let message = arrow_format::ipc::Message {
384 version: arrow_format::ipc::MetadataVersion::V5,
385 header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
386 arrow_format::ipc::RecordBatch {
387 length: array_len as i64,
388 nodes: Some(nodes),
389 buffers: Some(buffers),
390 compression,
391 variadic_buffer_counts,
392 },
393 ))),
394 body_length: encoded_message.arrow_data.len() as i64,
395 custom_metadata,
396 };
397
398 let mut builder = Builder::new();
399 let ipc_message = builder.finish(&message, None);
400 encoded_message.ipc_message = ipc_message.to_vec();
401}
402
/// Serializes the *values* of a dictionary into a complete IPC
/// `DictionaryBatch` message under `dict_id`.
///
/// The batch is written with `is_delta: false`, i.e. it replaces any previous
/// dictionary registered under the same id. Returns the flatbuffer header
/// (`ipc_message`) together with the serialized body (`arrow_data`).
pub fn encode_dictionary_values(
    dict_id: i64,
    values_array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);

    // An empty list is encoded as an absent field, not `Some(vec![])`.
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    // Serialize the values into the body, collecting field nodes and buffer
    // descriptors along the way.
    write(
        values_array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_native_little_endian(),
        options.compression,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: values_array.len() as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                // Full replacement dictionary, never a delta.
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    Ok(EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    })
}
459
/// Tracks which dictionaries have already been written, keyed by dictionary
/// id, so identical dictionaries are not re-encoded.
pub struct DictionaryTracker {
    /// Dictionary values already registered, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// When set, registering different values under an existing id is an
    /// error (the IPC *file* format does not allow dictionary replacement).
    pub cannot_replace: bool,
}
467
impl DictionaryTracker {
    /// Registers the dictionary values of `array` under `dict_id`.
    ///
    /// Returns `Ok(true)` when the dictionary is new (or replaces a previous
    /// one) and must therefore be emitted; `Ok(false)` when identical values
    /// were already registered under this id.
    ///
    /// # Errors
    /// When different values are already registered under `dict_id` and
    /// `cannot_replace` is set.
    ///
    /// # Panics
    /// If `array` is not a dictionary array (`unreachable!`).
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        let values = match array.dtype().to_storage() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Identical dictionary already emitted: nothing to do.
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                );
            }
        };

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}
510
/// An encoded IPC message: serialized flatbuffer header plus body bytes.
#[derive(Debug, Default)]
pub struct EncodedData {
    /// The serialized flatbuffer message header.
    pub ipc_message: Vec<u8>,
    /// The serialized body buffers of the message.
    pub arrow_data: Vec<u8>,
}
519
/// Same payload as [`EncodedData`], but backed by cheaply-cloneable
/// [`Bytes`] instead of owned vectors.
#[derive(Debug, Default)]
pub struct EncodedDataBytes {
    /// The serialized flatbuffer message header.
    pub ipc_message: Bytes,
    /// The serialized body buffers of the message.
    pub arrow_data: Bytes,
}
528
/// Returns the number of padding bytes required to round `len` up to the
/// next multiple of 64 (Arrow IPC buffers are 64-byte aligned).
///
/// Returns 0 when `len` is already aligned. Unlike the naive
/// `((len + 63) & !63) - len`, this form cannot overflow (and thus cannot
/// panic in debug builds) for `len` near `usize::MAX`.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    // (-len) mod 64, computed in two's complement.
    len.wrapping_neg() & 63
}
534
/// A record batch paired with optional IPC field metadata, each either owned
/// or borrowed.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    /// The columns of the record batch.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    /// Optional IPC fields describing the columns.
    fields: Option<Cow<'a, [IpcField]>>,
}
541
542impl Record<'_> {
543 pub fn fields(&self) -> Option<&[IpcField]> {
545 self.fields.as_deref()
546 }
547
548 pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
550 self.columns.borrow()
551 }
552}
553
/// Builds an owning [`Record`] with no IPC field metadata attached.
impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}
562
/// Builds a [`Record`] that owns its columns, with optional field metadata
/// (owned or borrowed, via `Into<Cow<[IpcField]>>`).
impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
574
/// Builds a [`Record`] that borrows its columns, with optional field
/// metadata (owned or borrowed, via `Into<Cow<[IpcField]>>`).
impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
586
587pub fn arrow_ipc_block(
589 offset: usize,
590 meta_data_length: usize,
591 body_length: usize,
592) -> arrow_format::ipc::Block {
593 arrow_format::ipc::Block {
594 offset: i64::try_from(offset).unwrap(),
595 meta_data_length: i32::try_from(meta_data_length).unwrap(),
596 body_length: i64::try_from(body_length).unwrap(),
597 }
598}