polars_arrow/io/ipc/write/
common.rs

use std::borrow::{Borrow, Cow};

use arrow_format::ipc;
use arrow_format::ipc::planus::Builder;
use polars_error::{PolarsResult, polars_bail, polars_err};
use polars_utils::compression::ZstdLevel;

use super::super::IpcField;
use super::write;
use crate::array::*;
use crate::datatypes::*;
use crate::io::ipc::endianness::is_native_little_endian;
use crate::io::ipc::read::Dictionaries;
use crate::legacy::prelude::LargeListArray;
use crate::match_integer_type;
use crate::record_batch::RecordBatchT;
use crate::types::Index;

/// Compression codec
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (framed)
    LZ4,
    /// ZSTD
    ZSTD(ZstdLevel),
}

/// Options declaring the behaviour of writing to IPC
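///
/// For example, `WriteOptions::default()` writes uncompressed buffers, while
/// `WriteOptions { compression: Some(Compression::LZ4) }` enables LZ4-frame compression.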
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}

/// Find the dictionaries that are new and need to be encoded.
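///
/// This walks `array` depth-first through nested types (dictionary, struct, list, union,
/// map) and pushes a `(dictionary id, dictionary array)` pair onto `dicts_to_encode` for
/// every dictionary that `dictionary_tracker` has not seen before.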
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // @Q? Should this not pick fields[0]?
            dictionaries_to_encode(
                field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a struct must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = &field.fields[..]; // todo: error instead
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}

/// Encode a dictionary array with a certain id.
///
/// # Panics
///
/// This will panic if the given array is not a [`DictionaryArray`].
pub fn encode_dictionary(
    dict_id: i64,
    array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
        panic!("Given array is not a DictionaryArray")
    };

    match_integer_type!(key_type, |$T| {
        let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();

        encode_dictionary_values(dict_id, array.values().as_ref(), options)
    })
}

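/// Find and encode, in one pass, all dictionaries under `array` that `dictionary_tracker`
/// has not seen yet, appending one [`EncodedData`] per new dictionary to
/// `encoded_dictionaries`.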
pub fn encode_new_dictionaries(
    field: &IpcField,
    array: &dyn Array,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> PolarsResult<()> {
    let mut dicts_to_encode = Vec::new();
    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
    for (dict_id, dict_array) in dicts_to_encode {
        encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);
    }
    Ok(())
}

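/// Encode a whole [`RecordBatchT`]: returns the encoded dictionary batches that must
/// precede the record batch in the IPC stream, together with the encoded batch itself.
///
/// A minimal usage sketch (assuming `chunk` and `fields` come from the schema being
/// written; `Dictionaries` is a map type, so `Default::default()` yields an empty tracker):
///
/// ```ignore
/// let mut tracker = DictionaryTracker {
///     dictionaries: Default::default(),
///     cannot_replace: true, // the file format forbids dictionary replacement
/// };
/// let (dicts, batch) = encode_chunk(&chunk, &fields, &mut tracker, &WriteOptions::default())?;
/// ```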
pub fn encode_chunk(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
    let mut encoded_message = EncodedData::default();
    let encoded_dictionaries = encode_chunk_amortized(
        chunk,
        fields,
        dictionary_tracker,
        options,
        &mut encoded_message,
    )?;
    Ok((encoded_dictionaries, encoded_message))
}

/// Like [`encode_chunk`], but reuses `encoded_message` to amortize the `EncodedData`
/// allocation across calls.
pub fn encode_chunk_amortized(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) -> PolarsResult<Vec<EncodedData>> {
    let mut encoded_dictionaries = vec![];

    for (field, array) in fields.iter().zip(chunk.as_ref()) {
        encode_new_dictionaries(
            field,
            array.as_ref(),
            options,
            dictionary_tracker,
            &mut encoded_dictionaries,
        )?;
    }
    encode_record_batch(chunk, options, encoded_message);

    Ok(encoded_dictionaries)
}

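/// Translate the user-facing [`Compression`] into the flatbuffers `BodyCompression`
/// descriptor, or `None` when writing uncompressed. Note that the ZSTD level is a
/// write-time parameter only; it is not recorded in the message.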
fn serialize_compression(
    compression: Option<Compression>,
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
    if let Some(compression) = compression {
        let codec = match compression {
            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
            Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,
        };
        Some(Box::new(arrow_format::ipc::BodyCompression {
            codec,
            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
        }))
    } else {
        None
    }
}

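/// Recursively collect, in depth-first order, the number of data buffers of every view
/// array (`Utf8View`/`BinaryView`) nested in `array`. These counts populate the
/// `variadic_buffer_counts` field of the IPC record batch message.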
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            // Subslicing can change the variadic buffer count, so we have to
            // slice here as well to stay synchronized.
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
        // is read.
        ArrowDataType::Dictionary(_, _, _) => (),
        _ => (),
    }
}

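/// Decide whether a view array is worth garbage-collecting before writing: if the waste
/// (`total_buffer_len - total_bytes_len`) reaches `min(total_bytes_len, 1024)` bytes,
/// return a compacted copy; otherwise borrow the array as-is.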
fn gc_bin_view<'a, T: ViewType + ?Sized>(
    arr: &'a Box<dyn Array>,
    concrete_arr: &'a BinaryViewArrayGeneric<T>,
) -> Cow<'a, Box<dyn Array>> {
    let bytes_len = concrete_arr.total_bytes_len();
    let buffer_len = concrete_arr.total_buffer_len();
    let extra_len = buffer_len.saturating_sub(bytes_len);
    if extra_len < bytes_len.min(1024) {
        // We can afford some tiny waste.
        Cow::Borrowed(arr)
    } else {
        // Force GC it.
        Cow::Owned(concrete_arr.clone().gc().boxed())
    }
}

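/// Encode a single array into the IPC body representation, appending its field nodes to
/// `nodes` and its (possibly compressed) buffers to `buffers` and `arrow_data`, advancing
/// `offset` by the number of body bytes written. View arrays may be garbage-collected
/// first so that sliced arrays do not drag their full backing buffers into the file.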
pub fn encode_array(
    array: &Box<dyn Array>,
    options: &WriteOptions,
    variadic_buffer_counts: &mut Vec<i64>,
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<ipc::FieldNode>,
    offset: &mut i64,
) {
    // We don't want to write all buffers in sliced arrays.
    let array = match array.dtype() {
        ArrowDataType::BinaryView => {
            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        ArrowDataType::Utf8View => {
            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        _ => Cow::Borrowed(array),
    };
    let array = array.as_ref().as_ref();

    set_variadic_buffer_counts(variadic_buffer_counts, array);

    write(
        array,
        buffers,
        arrow_data,
        nodes,
        offset,
        is_native_little_endian(),
        options.compression,
    )
}

/// Write a [`RecordBatchT`] into two sets of bytes, one for the message header (an encoded
/// `arrow_format::ipc::Message`) and the other for the batch's data.
pub fn encode_record_batch(
    chunk: &RecordBatchT<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    encoded_message.arrow_data.clear();

    let mut offset = 0;
    let mut variadic_buffer_counts = vec![];
    for array in chunk.arrays() {
        encode_array(
            array,
            options,
            &mut variadic_buffer_counts,
            &mut buffers,
            &mut encoded_message.arrow_data,
            &mut nodes,
            &mut offset,
        );
    }

    commit_encoded_arrays(
        chunk.len(),
        options,
        variadic_buffer_counts,
        buffers,
        nodes,
        encoded_message,
    );
}

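/// Build and serialize the flatbuffers `RecordBatch` message header for arrays that have
/// already been encoded into `encoded_message.arrow_data`, storing the result in
/// `encoded_message.ipc_message`.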
pub fn commit_encoded_arrays(
    array_len: usize,
    options: &WriteOptions,
    variadic_buffer_counts: Vec<i64>,
    buffers: Vec<ipc::Buffer>,
    nodes: Vec<ipc::FieldNode>,
    encoded_message: &mut EncodedData,
) {
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: array_len as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
                variadic_buffer_counts,
            },
        ))),
        body_length: encoded_message.arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
}

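/// Encode the values of a dictionary as a non-delta `DictionaryBatch` message with the
/// given id.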
pub fn encode_dictionary_values(
    dict_id: i64,
    values_array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);

    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    write(
        values_array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_native_little_endian(),
        options.compression,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: values_array.len() as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    Ok(EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    })
}

/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted,
/// which isn't allowed in the `FileWriter`.
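///
/// Both fields are public; construct the tracker with an empty [`Dictionaries`] map and
/// set `cannot_replace` to `true` for writers (such as the file writer) that must reject
/// dictionary replacement.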
pub struct DictionaryTracker {
    pub dictionaries: Dictionaries,
    pub cannot_replace: bool,
}

impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        let values = match array.dtype() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                );
            }
        }

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}

/// Stores the encoded data: an encoded `arrow_format::ipc::Message` header plus the
/// optional Arrow body data.
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded `arrow_format::ipc::Message`
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written; should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}

/// Calculate a 64-byte boundary and return the number of bytes needed to pad to it.
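///
/// For example, `pad_to_64(0) == 0`, `pad_to_64(1) == 63`, and `pad_to_64(64) == 0`.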
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}

/// A [`RecordBatchT`] with optional accompanying IPC fields.
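///
/// Construct one via the [`From`] impls below: from an owned [`RecordBatchT`], or from a
/// `(batch, Option<fields>)` pair by value or by reference.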
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    fields: Option<Cow<'a, [IpcField]>>,
}

impl Record<'_> {
    /// Get the IPC fields for this record.
    pub fn fields(&self) -> Option<&[IpcField]> {
        self.fields.as_deref()
    }

    /// Get the Arrow columns in this record.
    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
        self.columns.borrow()
    }
}

impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}

impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}

impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}