// arrow2/io/ipc/write/common.rs

1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc::planus::Builder;
4
5use crate::array::*;
6use crate::chunk::Chunk;
7use crate::datatypes::*;
8use crate::error::{Error, Result};
9use crate::io::ipc::endianess::is_native_little_endian;
10use crate::io::ipc::read::Dictionaries;
11
12use super::super::IpcField;
13use super::{write, write_dictionary};
14
/// Compression codec applied to each IPC buffer when writing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (frame format)
    LZ4,
    /// ZSTD
    ZSTD,
}
23
/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// `None` (the default) writes uncompressed buffers.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}
31
32fn encode_dictionary(
33    field: &IpcField,
34    array: &dyn Array,
35    options: &WriteOptions,
36    dictionary_tracker: &mut DictionaryTracker,
37    encoded_dictionaries: &mut Vec<EncodedData>,
38) -> Result<()> {
39    use PhysicalType::*;
40    match array.data_type().to_physical_type() {
41        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
42        | FixedSizeBinary => Ok(()),
43        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
44            let dict_id = field.dictionary_id
45                .ok_or_else(|| Error::InvalidArgumentError("Dictionaries must have an associated id".to_string()))?;
46
47            let emit = dictionary_tracker.insert(dict_id, array)?;
48
49            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
50            let values = array.values();
51            encode_dictionary(field,
52                values.as_ref(),
53                options,
54                dictionary_tracker,
55                encoded_dictionaries
56            )?;
57
58            if emit {
59                encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
60                    dict_id,
61                    array,
62                    options,
63                    is_native_little_endian(),
64                ));
65            };
66            Ok(())
67        }),
68        Struct => {
69            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
70            let fields = field.fields.as_slice();
71            if array.fields().len() != fields.len() {
72                return Err(Error::InvalidArgumentError(
73                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
74                ));
75            }
76            fields
77                .iter()
78                .zip(array.values().iter())
79                .try_for_each(|(field, values)| {
80                    encode_dictionary(
81                        field,
82                        values.as_ref(),
83                        options,
84                        dictionary_tracker,
85                        encoded_dictionaries,
86                    )
87                })
88        }
89        List => {
90            let values = array
91                .as_any()
92                .downcast_ref::<ListArray<i32>>()
93                .unwrap()
94                .values();
95            let field = &field.fields[0]; // todo: error instead
96            encode_dictionary(
97                field,
98                values.as_ref(),
99                options,
100                dictionary_tracker,
101                encoded_dictionaries,
102            )
103        }
104        LargeList => {
105            let values = array
106                .as_any()
107                .downcast_ref::<ListArray<i64>>()
108                .unwrap()
109                .values();
110            let field = &field.fields[0]; // todo: error instead
111            encode_dictionary(
112                field,
113                values.as_ref(),
114                options,
115                dictionary_tracker,
116                encoded_dictionaries,
117            )
118        }
119        FixedSizeList => {
120            let values = array
121                .as_any()
122                .downcast_ref::<FixedSizeListArray>()
123                .unwrap()
124                .values();
125            let field = &field.fields[0]; // todo: error instead
126            encode_dictionary(
127                field,
128                values.as_ref(),
129                options,
130                dictionary_tracker,
131                encoded_dictionaries,
132            )
133        }
134        Union => {
135            let values = array
136                .as_any()
137                .downcast_ref::<UnionArray>()
138                .unwrap()
139                .fields();
140            let fields = &field.fields[..]; // todo: error instead
141            if values.len() != fields.len() {
142                return Err(Error::InvalidArgumentError(
143                    "The number of fields in a union must equal the number of children in IpcField"
144                        .to_string(),
145                ));
146            }
147            fields
148                .iter()
149                .zip(values.iter())
150                .try_for_each(|(field, values)| {
151                    encode_dictionary(
152                        field,
153                        values.as_ref(),
154                        options,
155                        dictionary_tracker,
156                        encoded_dictionaries,
157                    )
158                })
159        }
160        Map => {
161            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
162            let field = &field.fields[0]; // todo: error instead
163            encode_dictionary(
164                field,
165                values.as_ref(),
166                options,
167                dictionary_tracker,
168                encoded_dictionaries,
169            )
170        }
171    }
172}
173
174pub fn encode_chunk(
175    chunk: &Chunk<Box<dyn Array>>,
176    fields: &[IpcField],
177    dictionary_tracker: &mut DictionaryTracker,
178    options: &WriteOptions,
179) -> Result<(Vec<EncodedData>, EncodedData)> {
180    let mut encoded_message = EncodedData::default();
181    let encoded_dictionaries = encode_chunk_amortized(
182        chunk,
183        fields,
184        dictionary_tracker,
185        options,
186        &mut encoded_message,
187    )?;
188    Ok((encoded_dictionaries, encoded_message))
189}
190
191// Amortizes `EncodedData` allocation.
192pub fn encode_chunk_amortized(
193    chunk: &Chunk<Box<dyn Array>>,
194    fields: &[IpcField],
195    dictionary_tracker: &mut DictionaryTracker,
196    options: &WriteOptions,
197    encoded_message: &mut EncodedData,
198) -> Result<Vec<EncodedData>> {
199    let mut encoded_dictionaries = vec![];
200
201    for (field, array) in fields.iter().zip(chunk.as_ref()) {
202        encode_dictionary(
203            field,
204            array.as_ref(),
205            options,
206            dictionary_tracker,
207            &mut encoded_dictionaries,
208        )?;
209    }
210
211    chunk_to_bytes_amortized(chunk, options, encoded_message);
212
213    Ok(encoded_dictionaries)
214}
215
216fn serialize_compression(
217    compression: Option<Compression>,
218) -> Option<Box<arrow_format::ipc::BodyCompression>> {
219    if let Some(compression) = compression {
220        let codec = match compression {
221            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
222            Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
223        };
224        Some(Box::new(arrow_format::ipc::BodyCompression {
225            codec,
226            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
227        }))
228    } else {
229        None
230    }
231}
232
/// Write [`Chunk`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data, storing both in `encoded_message`.
fn chunk_to_bytes_amortized(
    chunk: &Chunk<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    // Reuse the caller's previous body allocation (this is the "amortized" part):
    // take the buffer out, clear its contents but keep its capacity.
    let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
    arrow_data.clear();

    // `offset` accumulates across arrays so each buffer's position in the body is recorded.
    let mut offset = 0;
    for array in chunk.arrays() {
        write(
            array.as_ref(),
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            &mut offset,
            is_native_little_endian(),
            options.compression,
        )
    }

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: chunk.len() as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    // Serialize the flatbuffers header...
    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
    // ...and hand the (reused) body buffer back to the caller.
    encoded_message.arrow_data = arrow_data
}
279
/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the data, returning them as an [`EncodedData`] dictionary batch for `dict_id`.
fn dictionary_batch_to_bytes<K: DictionaryKey>(
    dict_id: i64,
    array: &DictionaryArray<K>,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];

    // `length` is the row count reported back by the serializer and used below as the
    // embedded record batch's length.
    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_little_endian,
        options.compression,
        false, // NOTE(review): presumably a "write keys too" flag — confirm in `write_dictionary`
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: length as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                })),
                // Full dictionaries only — deltas are never emitted here.
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    }
}
331
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// The dictionary values emitted so far, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// When `true`, replacing an already-emitted dictionary is an error (IPC *file* format).
    pub cannot_replace: bool,
}
339
impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> Result<bool> {
        // Extract the dictionary's values; `array` must be dictionary-typed.
        let values = match array.data_type() {
            DataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            }
            // Callers only invoke `insert` for dictionary arrays.
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                return Err(Error::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
        };

        // Record (or replace) the values for this id so future calls can compare against them.
        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}
383
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message (the flatbuffers header)
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
392
/// Returns the number of bytes needed to pad `len` up to the next 64-byte boundary.
///
/// Returns 0 when `len` is already a multiple of 64. (The previous doc comment claimed
/// 8-byte alignment; the code pads to 64 bytes.)
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    // `(len + 63) & !63` rounds up to the next multiple of 64 (power-of-two mask trick).
    ((len + 63) & !63) - len
}
398
/// An array [`Chunk`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    /// The columns, owned or borrowed from the caller.
    columns: Cow<'a, Chunk<Box<dyn Array>>>,
    /// Optional IPC metadata describing the columns.
    fields: Option<Cow<'a, [IpcField]>>,
}
405
406impl<'a> Record<'a> {
407    /// Get the IPC fields for this record.
408    pub fn fields(&self) -> Option<&[IpcField]> {
409        self.fields.as_deref()
410    }
411
412    /// Get the Arrow columns in this record.
413    pub fn columns(&self) -> &Chunk<Box<dyn Array>> {
414        self.columns.borrow()
415    }
416}
417
418impl From<Chunk<Box<dyn Array>>> for Record<'static> {
419    fn from(columns: Chunk<Box<dyn Array>>) -> Self {
420        Self {
421            columns: Cow::Owned(columns),
422            fields: None,
423        }
424    }
425}
426
427impl<'a, F> From<(Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
428where
429    F: Into<Cow<'a, [IpcField]>>,
430{
431    fn from((columns, fields): (Chunk<Box<dyn Array>>, Option<F>)) -> Self {
432        Self {
433            columns: Cow::Owned(columns),
434            fields: fields.map(|f| f.into()),
435        }
436    }
437}
438
439impl<'a, F> From<(&'a Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
440where
441    F: Into<Cow<'a, [IpcField]>>,
442{
443    fn from((columns, fields): (&'a Chunk<Box<dyn Array>>, Option<F>)) -> Self {
444        Self {
445            columns: Cow::Borrowed(columns),
446            fields: fields.map(|f| f.into()),
447        }
448    }
449}