//! Shared helpers for Arrow IPC writing.
//! Source path: `polars_arrow/io/ipc/write/common.rs`
1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc;
4use arrow_format::ipc::KeyValue;
5use arrow_format::ipc::planus::Builder;
6use bytes::Bytes;
7use polars_error::{PolarsResult, polars_bail, polars_err};
8use polars_utils::compression::ZstdLevel;
9
10use super::super::IpcField;
11use super::write;
12use crate::array::*;
13use crate::datatypes::*;
14use crate::io::ipc::endianness::is_native_little_endian;
15use crate::io::ipc::read::Dictionaries;
16use crate::legacy::prelude::LargeListArray;
17use crate::match_integer_type;
18use crate::record_batch::RecordBatchT;
19use crate::types::Index;
20
/// Compression codec applied to the bodies of IPC buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (framed)
    LZ4,
    /// ZSTD, carrying the configured compression level.
    ZSTD(ZstdLevel),
}
29
/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// `None` writes uncompressed buffers.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}
37
/// Find the dictionaries that are new and need to be encoded.
///
/// Recursively walks `array` (guided by the matching `field` metadata), registering every
/// dictionary encountered with `dictionary_tracker`. Dictionaries not yet emitted (or whose
/// values changed, when replacement is allowed) are appended to `dicts_to_encode` as
/// `(dictionary_id, values)` pairs, in depth-first order.
///
/// # Errors
/// - A dictionary field without an associated `dictionary_id`.
/// - Struct/union child counts disagreeing with the `IpcField` children.
/// - Nested types whose `IpcField` has no child entry.
/// - Errors propagated from [`DictionaryTracker::insert`] (e.g. replacement in file format).
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        // Flat types cannot contain dictionaries; nothing to do.
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            // `insert` returns true only when this (id, values) pair has not been emitted yet.
            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // @Q? Should this not pick fields[0]?
            // NOTE(review): recurses with the same `field` rather than a child field — this
            // presumes a dictionary's values reuse the parent's IPC metadata; confirm against
            // the reader's expectations for nested dictionaries.
            dictionaries_to_encode(field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation: "The number of fields in a struct must equal the number of children in IpcField");
            }
            // Recurse pairwise into each (IPC field, child array).
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = field.fields.as_slice();
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            // A map's single child is its entries struct; recurse into it.
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}
144
145/// Encode a dictionary array with a certain id.
146///
147/// # Panics
148///
149/// This will panic if the given array is not a [`DictionaryArray`].
150pub fn encode_dictionary(
151    dict_id: i64,
152    array: &dyn Array,
153    options: &WriteOptions,
154) -> PolarsResult<EncodedData> {
155    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
156        panic!("Given array is not a DictionaryArray")
157    };
158
159    match_integer_type!(key_type, |$T| {
160        let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();
161
162        encode_dictionary_values(dict_id, array.values().as_ref(), options)
163    })
164}
165
166pub fn encode_new_dictionaries(
167    field: &IpcField,
168    array: &dyn Array,
169    options: &WriteOptions,
170    dictionary_tracker: &mut DictionaryTracker,
171    encoded_dictionaries: &mut Vec<EncodedData>,
172) -> PolarsResult<()> {
173    let mut dicts_to_encode = Vec::new();
174    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
175    for (dict_id, dict_array) in dicts_to_encode {
176        encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);
177    }
178    Ok(())
179}
180
181pub fn encode_chunk(
182    chunk: &RecordBatchT<Box<dyn Array>>,
183    fields: &[IpcField],
184    dictionary_tracker: &mut DictionaryTracker,
185    options: &WriteOptions,
186) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
187    let mut encoded_message = EncodedData::default();
188    let encoded_dictionaries = encode_chunk_amortized(
189        chunk,
190        fields,
191        dictionary_tracker,
192        options,
193        &mut encoded_message,
194    )?;
195    Ok((encoded_dictionaries, encoded_message))
196}
197
198// Amortizes `EncodedData` allocation.
199pub fn encode_chunk_amortized(
200    chunk: &RecordBatchT<Box<dyn Array>>,
201    fields: &[IpcField],
202    dictionary_tracker: &mut DictionaryTracker,
203    options: &WriteOptions,
204    encoded_message: &mut EncodedData,
205) -> PolarsResult<Vec<EncodedData>> {
206    let mut encoded_dictionaries = vec![];
207
208    for (field, array) in fields.iter().zip(chunk.as_ref()) {
209        encode_new_dictionaries(
210            field,
211            array.as_ref(),
212            options,
213            dictionary_tracker,
214            &mut encoded_dictionaries,
215        )?;
216    }
217    encode_record_batch(chunk, options, encoded_message);
218
219    Ok(encoded_dictionaries)
220}
221
222fn serialize_compression(
223    compression: Option<Compression>,
224) -> Option<Box<arrow_format::ipc::BodyCompression>> {
225    if let Some(compression) = compression {
226        let codec = match compression {
227            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
228            Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,
229        };
230        Some(Box::new(arrow_format::ipc::BodyCompression {
231            codec,
232            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
233        }))
234    } else {
235        None
236    }
237}
238
/// Recursively appends, for every view array (`Utf8View`/`BinaryView`) reachable from `array`
/// in depth-first order, the number of variadic data buffers it holds.
///
/// These counts populate the `variadic_buffer_counts` field of the IPC RecordBatch message.
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype().to_storage() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            // Subslicing can change the variadic buffer count, so we have to
            // slice here as well to stay synchronized.
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
        // is read.
        ArrowDataType::Dictionary(_, _, _) => (),
        // NOTE(review): `List` (i32 offsets), `Map` and `Union` are not traversed here —
        // presumably polars never nests view arrays under those types; confirm.
        _ => (),
    }
}
277
278fn gc_bin_view<'a, T: ViewType + ?Sized>(
279    arr: &'a Box<dyn Array>,
280    concrete_arr: &'a BinaryViewArrayGeneric<T>,
281) -> Cow<'a, Box<dyn Array>> {
282    let bytes_len = concrete_arr.total_bytes_len();
283    let buffer_len = concrete_arr.total_buffer_len();
284    let extra_len = buffer_len.saturating_sub(bytes_len);
285    if extra_len < bytes_len.min(1024) {
286        // We can afford some tiny waste.
287        Cow::Borrowed(arr)
288    } else {
289        // Force GC it.
290        Cow::Owned(concrete_arr.clone().gc().boxed())
291    }
292}
293
294pub fn encode_array(
295    array: &Box<dyn Array>,
296    options: &WriteOptions,
297    variadic_buffer_counts: &mut Vec<i64>,
298    buffers: &mut Vec<ipc::Buffer>,
299    arrow_data: &mut Vec<u8>,
300    nodes: &mut Vec<ipc::FieldNode>,
301    offset: &mut i64,
302) {
303    // We don't want to write all buffers in sliced arrays.
304    let array = match array.dtype() {
305        ArrowDataType::BinaryView => {
306            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
307            gc_bin_view(array, concrete_arr)
308        },
309        ArrowDataType::Utf8View => {
310            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
311            gc_bin_view(array, concrete_arr)
312        },
313        _ => Cow::Borrowed(array),
314    };
315    let array = array.as_ref().as_ref();
316
317    set_variadic_buffer_counts(variadic_buffer_counts, array);
318
319    write(
320        array,
321        buffers,
322        arrow_data,
323        nodes,
324        offset,
325        is_native_little_endian(),
326        options.compression,
327    )
328}
329
330/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
331/// other for the batch's data
332pub fn encode_record_batch(
333    chunk: &RecordBatchT<Box<dyn Array>>,
334    options: &WriteOptions,
335    encoded_message: &mut EncodedData,
336) {
337    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
338    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
339    encoded_message.arrow_data.clear();
340
341    let mut offset = 0;
342    let mut variadic_buffer_counts = vec![];
343    for array in chunk.arrays() {
344        encode_array(
345            array,
346            options,
347            &mut variadic_buffer_counts,
348            &mut buffers,
349            &mut encoded_message.arrow_data,
350            &mut nodes,
351            &mut offset,
352        );
353    }
354
355    commit_encoded_arrays(
356        chunk.len(),
357        options,
358        variadic_buffer_counts,
359        buffers,
360        nodes,
361        None,
362        encoded_message,
363    );
364}
365
366pub fn commit_encoded_arrays(
367    array_len: usize,
368    options: &WriteOptions,
369    variadic_buffer_counts: Vec<i64>,
370    buffers: Vec<ipc::Buffer>,
371    nodes: Vec<ipc::FieldNode>,
372    custom_metadata: Option<Vec<KeyValue>>,
373    encoded_message: &mut EncodedData,
374) {
375    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
376        None
377    } else {
378        Some(variadic_buffer_counts)
379    };
380
381    let compression = serialize_compression(options.compression);
382
383    let message = arrow_format::ipc::Message {
384        version: arrow_format::ipc::MetadataVersion::V5,
385        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
386            arrow_format::ipc::RecordBatch {
387                length: array_len as i64,
388                nodes: Some(nodes),
389                buffers: Some(buffers),
390                compression,
391                variadic_buffer_counts,
392            },
393        ))),
394        body_length: encoded_message.arrow_data.len() as i64,
395        custom_metadata,
396    };
397
398    let mut builder = Builder::new();
399    let ipc_message = builder.finish(&message, None);
400    encoded_message.ipc_message = ipc_message.to_vec();
401}
402
403pub fn encode_dictionary_values(
404    dict_id: i64,
405    values_array: &dyn Array,
406    options: &WriteOptions,
407) -> PolarsResult<EncodedData> {
408    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
409    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
410    let mut arrow_data: Vec<u8> = vec![];
411    let mut variadic_buffer_counts = vec![];
412    set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);
413
414    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
415        None
416    } else {
417        Some(variadic_buffer_counts)
418    };
419
420    write(
421        values_array,
422        &mut buffers,
423        &mut arrow_data,
424        &mut nodes,
425        &mut 0,
426        is_native_little_endian(),
427        options.compression,
428    );
429
430    let compression = serialize_compression(options.compression);
431
432    let message = arrow_format::ipc::Message {
433        version: arrow_format::ipc::MetadataVersion::V5,
434        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
435            arrow_format::ipc::DictionaryBatch {
436                id: dict_id,
437                data: Some(Box::new(arrow_format::ipc::RecordBatch {
438                    length: values_array.len() as i64,
439                    nodes: Some(nodes),
440                    buffers: Some(buffers),
441                    compression,
442                    variadic_buffer_counts,
443                })),
444                is_delta: false,
445            },
446        ))),
447        body_length: arrow_data.len() as i64,
448        custom_metadata: None,
449    };
450
451    let mut builder = Builder::new();
452    let ipc_message = builder.finish(&message, None);
453
454    Ok(EncodedData {
455        ipc_message: ipc_message.to_vec(),
456        arrow_data,
457    })
458}
459
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Dictionary values already emitted, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// When `true`, replacing an already-emitted dictionary with different values is an
    /// error (required by the IPC *file* format); when `false`, the replacement is emitted.
    pub cannot_replace: bool,
}
467
impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    ///
    /// # Panics
    /// `unreachable!` if `array` is not dictionary-encoded — callers only invoke this for
    /// `PhysicalType::Dictionary` arrays.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        // Extract the dictionary's values; the key width is irrelevant for dedup purposes.
        let values = match array.dtype().to_storage() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                );
            }
        };

        // New (or replaced) dictionary: remember its values for future comparisons.
        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}
510
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message (the flatbuffers header bytes)
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
519
/// Like [`EncodedData`], but backed by reference-counted [`Bytes`] so the encoded message
/// can be shared/cloned cheaply.
#[derive(Debug, Default)]
pub struct EncodedDataBytes {
    /// An encoded ipc::Schema::Message
    pub ipc_message: Bytes,
    /// Arrow buffers to be written; should be empty for schema messages
    pub arrow_data: Bytes,
}
528
/// Return the number of bytes needed to pad `len` up to the next 64-byte boundary.
///
/// Returns `0` when `len` is already a multiple of 64. The Arrow IPC format requires
/// message bodies and buffers to be 64-byte aligned.
///
/// Note: the previous doc comment claimed an 8-byte boundary; the implementation has
/// always aligned to 64. This form (`(-len) mod 64`) is equivalent to
/// `((len + 63) & !63) - len` but cannot overflow for `len` near `usize::MAX`.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    len.wrapping_neg() & 63
}
534
/// An array [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    /// The batch's columns; borrowed or owned depending on how the record was built.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    /// Optional per-column IPC metadata; `None` means callers fall back to defaults.
    fields: Option<Cow<'a, [IpcField]>>,
}
541
542impl Record<'_> {
543    /// Get the IPC fields for this record.
544    pub fn fields(&self) -> Option<&[IpcField]> {
545        self.fields.as_deref()
546    }
547
548    /// Get the Arrow columns in this record.
549    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
550        self.columns.borrow()
551    }
552}
553
impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    /// Take ownership of the columns; no IPC fields are attached.
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}
562
impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    /// Take ownership of the columns, optionally attaching IPC fields.
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
574
impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    /// Borrow the columns (no copy), optionally attaching IPC fields.
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
586
587/// Create an IPC Block. Will panic when size limitations are not met.
588pub fn arrow_ipc_block(
589    offset: usize,
590    meta_data_length: usize,
591    body_length: usize,
592) -> arrow_format::ipc::Block {
593    arrow_format::ipc::Block {
594        offset: i64::try_from(offset).unwrap(),
595        meta_data_length: i32::try_from(meta_data_length).unwrap(),
596        body_length: i64::try_from(body_length).unwrap(),
597    }
598}