polars_arrow/io/ipc/write/common.rs

1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc;
4use arrow_format::ipc::planus::Builder;
5use polars_error::{PolarsResult, polars_bail, polars_err};
6
7use super::super::IpcField;
8use super::{write, write_dictionary};
9use crate::array::*;
10use crate::datatypes::*;
11use crate::io::ipc::endianness::is_native_little_endian;
12use crate::io::ipc::read::Dictionaries;
13use crate::legacy::prelude::LargeListArray;
14use crate::match_integer_type;
15use crate::record_batch::RecordBatchT;
16use crate::types::Index;
17
/// Compression codec applied to the IPC message body buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 compression using the LZ4 frame format.
    LZ4,
    /// ZSTD compression.
    ZSTD,
}
26
/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// `None` (the default) writes uncompressed buffers.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}
34
35/// Find the dictionary that are new and need to be encoded.
36pub fn dictionaries_to_encode(
37    field: &IpcField,
38    array: &dyn Array,
39    dictionary_tracker: &mut DictionaryTracker,
40    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
41) -> PolarsResult<()> {
42    use PhysicalType::*;
43    match array.dtype().to_physical_type() {
44        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
45        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
46        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
47            let dict_id = field.dictionary_id
48                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;
49
50            if dictionary_tracker.insert(dict_id, array)? {
51                dicts_to_encode.push((dict_id, array.to_boxed()));
52            }
53
54            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
55            let values = array.values();
56            // @Q? Should this not pick fields[0]?
57            dictionaries_to_encode(field,
58                values.as_ref(),
59                dictionary_tracker,
60                dicts_to_encode,
61            )?;
62
63            Ok(())
64        }),
65        Struct => {
66            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
67            let fields = field.fields.as_slice();
68            if array.fields().len() != fields.len() {
69                polars_bail!(InvalidOperation:
70                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
71                );
72            }
73            fields
74                .iter()
75                .zip(array.values().iter())
76                .try_for_each(|(field, values)| {
77                    dictionaries_to_encode(
78                        field,
79                        values.as_ref(),
80                        dictionary_tracker,
81                        dicts_to_encode,
82                    )
83                })
84        },
85        List => {
86            let values = array
87                .as_any()
88                .downcast_ref::<ListArray<i32>>()
89                .unwrap()
90                .values();
91            let field = &field.fields[0]; // todo: error instead
92            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
93        },
94        LargeList => {
95            let values = array
96                .as_any()
97                .downcast_ref::<ListArray<i64>>()
98                .unwrap()
99                .values();
100            let field = &field.fields[0]; // todo: error instead
101            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
102        },
103        FixedSizeList => {
104            let values = array
105                .as_any()
106                .downcast_ref::<FixedSizeListArray>()
107                .unwrap()
108                .values();
109            let field = &field.fields[0]; // todo: error instead
110            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
111        },
112        Union => {
113            let values = array
114                .as_any()
115                .downcast_ref::<UnionArray>()
116                .unwrap()
117                .fields();
118            let fields = &field.fields[..]; // todo: error instead
119            if values.len() != fields.len() {
120                polars_bail!(InvalidOperation:
121                    "The number of fields in a union must equal the number of children in IpcField"
122                );
123            }
124            fields
125                .iter()
126                .zip(values.iter())
127                .try_for_each(|(field, values)| {
128                    dictionaries_to_encode(
129                        field,
130                        values.as_ref(),
131                        dictionary_tracker,
132                        dicts_to_encode,
133                    )
134                })
135        },
136        Map => {
137            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
138            let field = &field.fields[0]; // todo: error instead
139            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
140        },
141    }
142}
143
144/// Encode a dictionary array with a certain id.
145///
146/// # Panics
147///
148/// This will panic if the given array is not a [`DictionaryArray`].
149pub fn encode_dictionary(
150    dict_id: i64,
151    array: &dyn Array,
152    options: &WriteOptions,
153    encoded_dictionaries: &mut Vec<EncodedData>,
154) -> PolarsResult<()> {
155    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
156        panic!("Given array is not a DictionaryArray")
157    };
158
159    match_integer_type!(key_type, |$T| {
160        let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
161        encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
162            dict_id,
163            array,
164            options,
165            is_native_little_endian(),
166        ));
167    });
168
169    Ok(())
170}
171
172pub fn encode_new_dictionaries(
173    field: &IpcField,
174    array: &dyn Array,
175    options: &WriteOptions,
176    dictionary_tracker: &mut DictionaryTracker,
177    encoded_dictionaries: &mut Vec<EncodedData>,
178) -> PolarsResult<()> {
179    let mut dicts_to_encode = Vec::new();
180    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
181    for (dict_id, dict_array) in dicts_to_encode {
182        encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
183    }
184    Ok(())
185}
186
187pub fn encode_chunk(
188    chunk: &RecordBatchT<Box<dyn Array>>,
189    fields: &[IpcField],
190    dictionary_tracker: &mut DictionaryTracker,
191    options: &WriteOptions,
192) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
193    let mut encoded_message = EncodedData::default();
194    let encoded_dictionaries = encode_chunk_amortized(
195        chunk,
196        fields,
197        dictionary_tracker,
198        options,
199        &mut encoded_message,
200    )?;
201    Ok((encoded_dictionaries, encoded_message))
202}
203
204// Amortizes `EncodedData` allocation.
205pub fn encode_chunk_amortized(
206    chunk: &RecordBatchT<Box<dyn Array>>,
207    fields: &[IpcField],
208    dictionary_tracker: &mut DictionaryTracker,
209    options: &WriteOptions,
210    encoded_message: &mut EncodedData,
211) -> PolarsResult<Vec<EncodedData>> {
212    let mut encoded_dictionaries = vec![];
213
214    for (field, array) in fields.iter().zip(chunk.as_ref()) {
215        encode_new_dictionaries(
216            field,
217            array.as_ref(),
218            options,
219            dictionary_tracker,
220            &mut encoded_dictionaries,
221        )?;
222    }
223    encode_record_batch(chunk, options, encoded_message);
224
225    Ok(encoded_dictionaries)
226}
227
228fn serialize_compression(
229    compression: Option<Compression>,
230) -> Option<Box<arrow_format::ipc::BodyCompression>> {
231    if let Some(compression) = compression {
232        let codec = match compression {
233            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
234            Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
235        };
236        Some(Box::new(arrow_format::ipc::BodyCompression {
237            codec,
238            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
239        }))
240    } else {
241        None
242    }
243}
244
245fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
246    match array.dtype() {
247        ArrowDataType::Utf8View => {
248            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
249            counts.push(array.data_buffers().len() as i64);
250        },
251        ArrowDataType::BinaryView => {
252            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
253            counts.push(array.data_buffers().len() as i64);
254        },
255        ArrowDataType::Struct(_) => {
256            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
257            for array in array.values() {
258                set_variadic_buffer_counts(counts, array.as_ref())
259            }
260        },
261        ArrowDataType::LargeList(_) => {
262            // Subslicing can change the variadic buffer count, so we have to
263            // slice here as well to stay synchronized.
264            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
265            let offsets = array.offsets().buffer();
266            let first = *offsets.first().unwrap();
267            let last = *offsets.last().unwrap();
268            let subslice = array
269                .values()
270                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
271            set_variadic_buffer_counts(counts, &*subslice)
272        },
273        ArrowDataType::FixedSizeList(_, _) => {
274            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
275            set_variadic_buffer_counts(counts, array.values().as_ref())
276        },
277        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
278        // is read.
279        ArrowDataType::Dictionary(_, _, _) => (),
280        _ => (),
281    }
282}
283
284fn gc_bin_view<'a, T: ViewType + ?Sized>(
285    arr: &'a Box<dyn Array>,
286    concrete_arr: &'a BinaryViewArrayGeneric<T>,
287) -> Cow<'a, Box<dyn Array>> {
288    let bytes_len = concrete_arr.total_bytes_len();
289    let buffer_len = concrete_arr.total_buffer_len();
290    let extra_len = buffer_len.saturating_sub(bytes_len);
291    if extra_len < bytes_len.min(1024) {
292        // We can afford some tiny waste.
293        Cow::Borrowed(arr)
294    } else {
295        // Force GC it.
296        Cow::Owned(concrete_arr.clone().gc().boxed())
297    }
298}
299
300pub fn encode_array(
301    array: &Box<dyn Array>,
302    options: &WriteOptions,
303    variadic_buffer_counts: &mut Vec<i64>,
304    buffers: &mut Vec<ipc::Buffer>,
305    arrow_data: &mut Vec<u8>,
306    nodes: &mut Vec<ipc::FieldNode>,
307    offset: &mut i64,
308) {
309    // We don't want to write all buffers in sliced arrays.
310    let array = match array.dtype() {
311        ArrowDataType::BinaryView => {
312            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
313            gc_bin_view(array, concrete_arr)
314        },
315        ArrowDataType::Utf8View => {
316            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
317            gc_bin_view(array, concrete_arr)
318        },
319        _ => Cow::Borrowed(array),
320    };
321    let array = array.as_ref().as_ref();
322
323    set_variadic_buffer_counts(variadic_buffer_counts, array);
324
325    write(
326        array,
327        buffers,
328        arrow_data,
329        nodes,
330        offset,
331        is_native_little_endian(),
332        options.compression,
333    )
334}
335
336/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
337/// other for the batch's data
338pub fn encode_record_batch(
339    chunk: &RecordBatchT<Box<dyn Array>>,
340    options: &WriteOptions,
341    encoded_message: &mut EncodedData,
342) {
343    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
344    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
345    encoded_message.arrow_data.clear();
346
347    let mut offset = 0;
348    let mut variadic_buffer_counts = vec![];
349    for array in chunk.arrays() {
350        encode_array(
351            array,
352            options,
353            &mut variadic_buffer_counts,
354            &mut buffers,
355            &mut encoded_message.arrow_data,
356            &mut nodes,
357            &mut offset,
358        );
359    }
360
361    commit_encoded_arrays(
362        chunk.len(),
363        options,
364        variadic_buffer_counts,
365        buffers,
366        nodes,
367        encoded_message,
368    );
369}
370
371pub fn commit_encoded_arrays(
372    array_len: usize,
373    options: &WriteOptions,
374    variadic_buffer_counts: Vec<i64>,
375    buffers: Vec<ipc::Buffer>,
376    nodes: Vec<ipc::FieldNode>,
377    encoded_message: &mut EncodedData,
378) {
379    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
380        None
381    } else {
382        Some(variadic_buffer_counts)
383    };
384
385    let compression = serialize_compression(options.compression);
386
387    let message = arrow_format::ipc::Message {
388        version: arrow_format::ipc::MetadataVersion::V5,
389        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
390            arrow_format::ipc::RecordBatch {
391                length: array_len as i64,
392                nodes: Some(nodes),
393                buffers: Some(buffers),
394                compression,
395                variadic_buffer_counts,
396            },
397        ))),
398        body_length: encoded_message.arrow_data.len() as i64,
399        custom_metadata: None,
400    };
401
402    let mut builder = Builder::new();
403    let ipc_message = builder.finish(&message, None);
404    encoded_message.ipc_message = ipc_message.to_vec();
405}
406
407/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
408/// other for the data
409fn dictionary_batch_to_bytes<K: DictionaryKey>(
410    dict_id: i64,
411    array: &DictionaryArray<K>,
412    options: &WriteOptions,
413    is_little_endian: bool,
414) -> EncodedData {
415    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
416    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
417    let mut arrow_data: Vec<u8> = vec![];
418    let mut variadic_buffer_counts = vec![];
419    set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());
420
421    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
422        None
423    } else {
424        Some(variadic_buffer_counts)
425    };
426
427    let length = write_dictionary(
428        array,
429        &mut buffers,
430        &mut arrow_data,
431        &mut nodes,
432        &mut 0,
433        is_little_endian,
434        options.compression,
435        false,
436    );
437
438    let compression = serialize_compression(options.compression);
439
440    let message = arrow_format::ipc::Message {
441        version: arrow_format::ipc::MetadataVersion::V5,
442        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
443            arrow_format::ipc::DictionaryBatch {
444                id: dict_id,
445                data: Some(Box::new(arrow_format::ipc::RecordBatch {
446                    length: length as i64,
447                    nodes: Some(nodes),
448                    buffers: Some(buffers),
449                    compression,
450                    variadic_buffer_counts,
451                })),
452                is_delta: false,
453            },
454        ))),
455        body_length: arrow_data.len() as i64,
456        custom_metadata: None,
457    };
458
459    let mut builder = Builder::new();
460    let ipc_message = builder.finish(&message, None);
461
462    EncodedData {
463        ipc_message: ipc_message.to_vec(),
464        arrow_data,
465    }
466}
467
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Dictionary values already emitted, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// When true, replacing an already-emitted dictionary is an error (file format).
    pub cannot_replace: bool,
}
475
476impl DictionaryTracker {
477    /// Keep track of the dictionary with the given ID and values. Behavior:
478    ///
479    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
480    ///   that the dictionary was not actually inserted (because it's already been seen).
481    /// * If this ID has been written already but with different data, and this tracker is
482    ///   configured to return an error, return an error.
483    /// * If the tracker has not been configured to error on replacement or this dictionary
484    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
485    ///   inserted.
486    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
487        let values = match array.dtype() {
488            ArrowDataType::Dictionary(key_type, _, _) => {
489                match_integer_type!(key_type, |$T| {
490                    let array = array
491                        .as_any()
492                        .downcast_ref::<DictionaryArray<$T>>()
493                        .unwrap();
494                    array.values()
495                })
496            },
497            _ => unreachable!(),
498        };
499
500        // If a dictionary with this id was already emitted, check if it was the same.
501        if let Some(last) = self.dictionaries.get(&dict_id) {
502            if last.as_ref() == values.as_ref() {
503                // Same dictionary values => no need to emit it again
504                return Ok(false);
505            } else if self.cannot_replace {
506                polars_bail!(InvalidOperation:
507                    "Dictionary replacement detected when writing IPC file format. \
508                     Arrow IPC files only support a single dictionary for a given field \
509                     across all batches."
510                );
511            }
512        };
513
514        self.dictionaries.insert(dict_id, values.clone());
515        Ok(true)
516    }
517}
518
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message (flatbuffers bytes)
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
527
/// Return the number of padding bytes needed to round `len` up to the next multiple
/// of 64 (the IPC buffer alignment); 0 when `len` is already 64-byte aligned.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    // Round up to the next multiple of 64, then subtract the original length.
    ((len + 63) & !63) - len
}
533
/// An array [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    // The columns of this record; may be owned or borrowed.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    // Optional IPC metadata (e.g. dictionary ids) matching `columns`.
    fields: Option<Cow<'a, [IpcField]>>,
}
540
541impl Record<'_> {
542    /// Get the IPC fields for this record.
543    pub fn fields(&self) -> Option<&[IpcField]> {
544        self.fields.as_deref()
545    }
546
547    /// Get the Arrow columns in this record.
548    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
549        self.columns.borrow()
550    }
551}
552
553impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
554    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
555        Self {
556            columns: Cow::Owned(columns),
557            fields: None,
558        }
559    }
560}
561
562impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
563where
564    F: Into<Cow<'a, [IpcField]>>,
565{
566    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
567        Self {
568            columns: Cow::Owned(columns),
569            fields: fields.map(|f| f.into()),
570        }
571    }
572}
573
574impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
575where
576    F: Into<Cow<'a, [IpcField]>>,
577{
578    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
579        Self {
580            columns: Cow::Borrowed(columns),
581            fields: fields.map(|f| f.into()),
582        }
583    }
584}