use std::borrow::{Borrow, Cow};

use arrow_format::ipc;
use arrow_format::ipc::planus::Builder;
use polars_error::{PolarsResult, polars_bail, polars_err};
use polars_utils::compression::ZstdLevel;

use super::super::IpcField;
use super::write;
use crate::array::*;
use crate::datatypes::*;
use crate::io::ipc::endianness::is_native_little_endian;
use crate::io::ipc::read::Dictionaries;
use crate::legacy::prelude::LargeListArray;
use crate::match_integer_type;
use crate::record_batch::RecordBatchT;
use crate::types::Index;

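/// Compression codec applied to the buffers of the IPC message body.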
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (frame format)
    LZ4,
    /// ZSTD with an explicit compression level
    ZSTD(ZstdLevel),
}

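/// Options controlling how IPC messages are written.
///
/// A minimal sketch of constructing options (uses only names defined in this module):
///
/// ```ignore
/// let options = WriteOptions { compression: Some(Compression::LZ4) };
/// ```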
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and, if so, which codec to use.
    pub compression: Option<Compression>,
}

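/// Recursively walk `array` and collect every dictionary that still has to be written,
/// pushing `(dictionary id, dictionary array)` pairs onto `dicts_to_encode` for dictionaries
/// that `dictionary_tracker` has not seen before.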
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            dictionaries_to_encode(
                field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a struct must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = &field.fields[..];
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = &field.fields[0];
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}

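/// Encode the values of a dictionary array as an IPC `DictionaryBatch` message.
///
/// # Panics
/// Panics if `array` is not a `DictionaryArray`.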
pub fn encode_dictionary(
    dict_id: i64,
    array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
        panic!("Given array is not a DictionaryArray")
    };

    match_integer_type!(key_type, |$T| {
        let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();

        encode_dictionary_values(dict_id, array.values().as_ref(), options)
    })
}

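/// Collect and encode every dictionary in `array` that `dictionary_tracker` has not yet seen,
/// appending the resulting messages to `encoded_dictionaries`.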
pub fn encode_new_dictionaries(
    field: &IpcField,
    array: &dyn Array,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> PolarsResult<()> {
    let mut dicts_to_encode = Vec::new();
    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
    for (dict_id, dict_array) in dicts_to_encode {
        encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);
    }
    Ok(())
}

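/// Encode a record batch into an IPC `RecordBatch` message, together with the dictionary
/// batches that have to be written before it.
///
/// A minimal usage sketch (assuming `chunk` and `fields` are already set up and that
/// `Dictionaries` implements `Default`; not compiled as a doctest):
///
/// ```ignore
/// let mut tracker = DictionaryTracker {
///     dictionaries: Default::default(),
///     cannot_replace: true, // IPC *file* format: dictionaries may not be replaced
/// };
/// let options = WriteOptions { compression: None };
/// let (dict_messages, batch_message) = encode_chunk(&chunk, &fields, &mut tracker, &options)?;
/// ```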
pub fn encode_chunk(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
    let mut encoded_message = EncodedData::default();
    let encoded_dictionaries = encode_chunk_amortized(
        chunk,
        fields,
        dictionary_tracker,
        options,
        &mut encoded_message,
    )?;
    Ok((encoded_dictionaries, encoded_message))
}

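/// Same as [`encode_chunk`], but writes into a caller-provided `encoded_message` so that its
/// allocations can be reused across calls.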
pub fn encode_chunk_amortized(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) -> PolarsResult<Vec<EncodedData>> {
    let mut encoded_dictionaries = vec![];

    for (field, array) in fields.iter().zip(chunk.as_ref()) {
        encode_new_dictionaries(
            field,
            array.as_ref(),
            options,
            dictionary_tracker,
            &mut encoded_dictionaries,
        )?;
    }
    encode_record_batch(chunk, options, encoded_message);

    Ok(encoded_dictionaries)
}

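/// Map the user-facing [`Compression`] option to the flatbuffers `BodyCompression` message.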
fn serialize_compression(
    compression: Option<Compression>,
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
    if let Some(compression) = compression {
        let codec = match compression {
            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
            Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,
        };
        Some(Box::new(arrow_format::ipc::BodyCompression {
            codec,
            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
        }))
    } else {
        None
    }
}

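/// Recursively count the variadic data buffers of every view array (`Utf8View` / `BinaryView`)
/// nested in `array`; these counts end up in the `variadic_buffer_counts` field of the IPC
/// record batch.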
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        ArrowDataType::Dictionary(_, _, _) => (),
        _ => (),
    }
}

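/// If a view array owns noticeably more buffer bytes than its views actually reference,
/// compact ("garbage collect") it before writing so unused bytes are not serialized.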
fn gc_bin_view<'a, T: ViewType + ?Sized>(
    arr: &'a Box<dyn Array>,
    concrete_arr: &'a BinaryViewArrayGeneric<T>,
) -> Cow<'a, Box<dyn Array>> {
    let bytes_len = concrete_arr.total_bytes_len();
    let buffer_len = concrete_arr.total_buffer_len();
    let extra_len = buffer_len.saturating_sub(bytes_len);
    if extra_len < bytes_len.min(1024) {
        // The amount of wasted space is small; write the array as-is.
        Cow::Borrowed(arr)
    } else {
        // Compact the array so the unused buffer bytes are not written.
        Cow::Owned(concrete_arr.clone().gc().boxed())
    }
}

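/// Encode a single array into the IPC body: its buffers are appended to `arrow_data` while the
/// matching `Buffer` and `FieldNode` metadata entries are pushed onto `buffers` and `nodes`.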
pub fn encode_array(
    array: &Box<dyn Array>,
    options: &WriteOptions,
    variadic_buffer_counts: &mut Vec<i64>,
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<ipc::FieldNode>,
    offset: &mut i64,
) {
    let array = match array.dtype() {
        ArrowDataType::BinaryView => {
            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        ArrowDataType::Utf8View => {
            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        _ => Cow::Borrowed(array),
    };
    let array = array.as_ref().as_ref();

    set_variadic_buffer_counts(variadic_buffer_counts, array);

    write(
        array,
        buffers,
        arrow_data,
        nodes,
        offset,
        is_native_little_endian(),
        options.compression,
    )
}

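/// Encode an entire record batch into `encoded_message`: the flatbuffers metadata ends up in
/// `ipc_message` and the (optionally compressed) buffers in `arrow_data`.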
pub fn encode_record_batch(
    chunk: &RecordBatchT<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    encoded_message.arrow_data.clear();

    let mut offset = 0;
    let mut variadic_buffer_counts = vec![];
    for array in chunk.arrays() {
        encode_array(
            array,
            options,
            &mut variadic_buffer_counts,
            &mut buffers,
            &mut encoded_message.arrow_data,
            &mut nodes,
            &mut offset,
        );
    }

    commit_encoded_arrays(
        chunk.len(),
        options,
        variadic_buffer_counts,
        buffers,
        nodes,
        encoded_message,
    );
}

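/// Build the flatbuffers `RecordBatch` header for buffers that were already encoded with
/// [`encode_array`] and store it in `encoded_message.ipc_message`.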
pub fn commit_encoded_arrays(
    array_len: usize,
    options: &WriteOptions,
    variadic_buffer_counts: Vec<i64>,
    buffers: Vec<ipc::Buffer>,
    nodes: Vec<ipc::FieldNode>,
    encoded_message: &mut EncodedData,
) {
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: array_len as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
                variadic_buffer_counts,
            },
        ))),
        body_length: encoded_message.arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
}

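/// Encode the values of a dictionary as an IPC `DictionaryBatch` message with the given
/// dictionary id.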
pub fn encode_dictionary_values(
    dict_id: i64,
    values_array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);

    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    write(
        values_array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_native_little_endian(),
        options.compression,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: values_array.len() as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    Ok(EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    })
}

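/// Keeps track of the dictionaries that have already been written during an IPC write, so that
/// each dictionary is only emitted when it is new (or has changed, where the format allows it).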
pub struct DictionaryTracker {
    /// The dictionary values already written, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// If `true`, replacing an already-written dictionary is an error (IPC file format).
    pub cannot_replace: bool,
}

impl DictionaryTracker {
    /// Keep track of the dictionary with the given id and values. Behavior:
    ///
    /// * If this id has been written before with the same data, return `Ok(false)`: the
    ///   dictionary does not have to be emitted again.
    /// * If this id has been written before with different data and `cannot_replace` is set,
    ///   return an error.
    /// * Otherwise, record the dictionary and return `Ok(true)` to indicate that it must be
    ///   emitted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        let values = match array.dtype() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values: no need to emit it again.
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                    Arrow IPC files only support a single dictionary for a given field \
                    across all batches."
                );
            }
        };

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}

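/// An encoded IPC message: the flatbuffers header plus the message body holding the array data.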
#[derive(Debug, Default)]
pub struct EncodedData {
    /// The flatbuffers-encoded message header.
    pub ipc_message: Vec<u8>,
    /// The message body: the encoded array buffers.
    pub arrow_data: Vec<u8>,
}

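/// Number of padding bytes needed to round `len` up to the next multiple of 64 bytes (the
/// alignment used for IPC buffers), e.g. `pad_to_64(3) == 61` and `pad_to_64(64) == 0`.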
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}

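/// A record batch paired with optional IPC field metadata, either borrowed or owned.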
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    fields: Option<Cow<'a, [IpcField]>>,
}

impl Record<'_> {
    /// The IPC fields describing the columns, if provided.
    pub fn fields(&self) -> Option<&[IpcField]> {
        self.fields.as_deref()
    }

    /// The underlying record batch.
    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
        self.columns.borrow()
    }
}

impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}

impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}

impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}