1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc;
4use arrow_format::ipc::planus::Builder;
5use polars_error::{PolarsResult, polars_bail, polars_err};
6
7use super::super::IpcField;
8use super::{write, write_dictionary};
9use crate::array::*;
10use crate::datatypes::*;
11use crate::io::ipc::endianness::is_native_little_endian;
12use crate::io::ipc::read::Dictionaries;
13use crate::legacy::prelude::LargeListArray;
14use crate::match_integer_type;
15use crate::record_batch::RecordBatchT;
16use crate::types::Index;
17
/// Compression codecs that can be applied to the buffers of an IPC message body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4; serialized into the message header as the `Lz4Frame` codec.
    LZ4,
    /// Zstandard; serialized into the message header as the `Zstd` codec.
    ZSTD,
}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
29pub struct WriteOptions {
30 pub compression: Option<Compression>,
33}
34
35pub fn dictionaries_to_encode(
37 field: &IpcField,
38 array: &dyn Array,
39 dictionary_tracker: &mut DictionaryTracker,
40 dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
41) -> PolarsResult<()> {
42 use PhysicalType::*;
43 match array.dtype().to_physical_type() {
44 Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
45 | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
46 Dictionary(key_type) => match_integer_type!(key_type, |$T| {
47 let dict_id = field.dictionary_id
48 .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;
49
50 if dictionary_tracker.insert(dict_id, array)? {
51 dicts_to_encode.push((dict_id, array.to_boxed()));
52 }
53
54 let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
55 let values = array.values();
56 dictionaries_to_encode(field,
58 values.as_ref(),
59 dictionary_tracker,
60 dicts_to_encode,
61 )?;
62
63 Ok(())
64 }),
65 Struct => {
66 let array = array.as_any().downcast_ref::<StructArray>().unwrap();
67 let fields = field.fields.as_slice();
68 if array.fields().len() != fields.len() {
69 polars_bail!(InvalidOperation:
70 "The number of fields in a struct must equal the number of children in IpcField".to_string(),
71 );
72 }
73 fields
74 .iter()
75 .zip(array.values().iter())
76 .try_for_each(|(field, values)| {
77 dictionaries_to_encode(
78 field,
79 values.as_ref(),
80 dictionary_tracker,
81 dicts_to_encode,
82 )
83 })
84 },
85 List => {
86 let values = array
87 .as_any()
88 .downcast_ref::<ListArray<i32>>()
89 .unwrap()
90 .values();
91 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
93 },
94 LargeList => {
95 let values = array
96 .as_any()
97 .downcast_ref::<ListArray<i64>>()
98 .unwrap()
99 .values();
100 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
102 },
103 FixedSizeList => {
104 let values = array
105 .as_any()
106 .downcast_ref::<FixedSizeListArray>()
107 .unwrap()
108 .values();
109 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
111 },
112 Union => {
113 let values = array
114 .as_any()
115 .downcast_ref::<UnionArray>()
116 .unwrap()
117 .fields();
118 let fields = &field.fields[..]; if values.len() != fields.len() {
120 polars_bail!(InvalidOperation:
121 "The number of fields in a union must equal the number of children in IpcField"
122 );
123 }
124 fields
125 .iter()
126 .zip(values.iter())
127 .try_for_each(|(field, values)| {
128 dictionaries_to_encode(
129 field,
130 values.as_ref(),
131 dictionary_tracker,
132 dicts_to_encode,
133 )
134 })
135 },
136 Map => {
137 let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
138 let field = &field.fields[0]; dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
140 },
141 }
142}
143
144pub fn encode_dictionary(
150 dict_id: i64,
151 array: &dyn Array,
152 options: &WriteOptions,
153 encoded_dictionaries: &mut Vec<EncodedData>,
154) -> PolarsResult<()> {
155 let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
156 panic!("Given array is not a DictionaryArray")
157 };
158
159 match_integer_type!(key_type, |$T| {
160 let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
161 encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
162 dict_id,
163 array,
164 options,
165 is_native_little_endian(),
166 ));
167 });
168
169 Ok(())
170}
171
172pub fn encode_new_dictionaries(
173 field: &IpcField,
174 array: &dyn Array,
175 options: &WriteOptions,
176 dictionary_tracker: &mut DictionaryTracker,
177 encoded_dictionaries: &mut Vec<EncodedData>,
178) -> PolarsResult<()> {
179 let mut dicts_to_encode = Vec::new();
180 dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
181 for (dict_id, dict_array) in dicts_to_encode {
182 encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
183 }
184 Ok(())
185}
186
187pub fn encode_chunk(
188 chunk: &RecordBatchT<Box<dyn Array>>,
189 fields: &[IpcField],
190 dictionary_tracker: &mut DictionaryTracker,
191 options: &WriteOptions,
192) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
193 let mut encoded_message = EncodedData::default();
194 let encoded_dictionaries = encode_chunk_amortized(
195 chunk,
196 fields,
197 dictionary_tracker,
198 options,
199 &mut encoded_message,
200 )?;
201 Ok((encoded_dictionaries, encoded_message))
202}
203
204pub fn encode_chunk_amortized(
206 chunk: &RecordBatchT<Box<dyn Array>>,
207 fields: &[IpcField],
208 dictionary_tracker: &mut DictionaryTracker,
209 options: &WriteOptions,
210 encoded_message: &mut EncodedData,
211) -> PolarsResult<Vec<EncodedData>> {
212 let mut encoded_dictionaries = vec![];
213
214 for (field, array) in fields.iter().zip(chunk.as_ref()) {
215 encode_new_dictionaries(
216 field,
217 array.as_ref(),
218 options,
219 dictionary_tracker,
220 &mut encoded_dictionaries,
221 )?;
222 }
223 encode_record_batch(chunk, options, encoded_message);
224
225 Ok(encoded_dictionaries)
226}
227
228fn serialize_compression(
229 compression: Option<Compression>,
230) -> Option<Box<arrow_format::ipc::BodyCompression>> {
231 if let Some(compression) = compression {
232 let codec = match compression {
233 Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
234 Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
235 };
236 Some(Box::new(arrow_format::ipc::BodyCompression {
237 codec,
238 method: arrow_format::ipc::BodyCompressionMethod::Buffer,
239 }))
240 } else {
241 None
242 }
243}
244
245fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
246 match array.dtype() {
247 ArrowDataType::Utf8View => {
248 let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
249 counts.push(array.data_buffers().len() as i64);
250 },
251 ArrowDataType::BinaryView => {
252 let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
253 counts.push(array.data_buffers().len() as i64);
254 },
255 ArrowDataType::Struct(_) => {
256 let array = array.as_any().downcast_ref::<StructArray>().unwrap();
257 for array in array.values() {
258 set_variadic_buffer_counts(counts, array.as_ref())
259 }
260 },
261 ArrowDataType::LargeList(_) => {
262 let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
265 let offsets = array.offsets().buffer();
266 let first = *offsets.first().unwrap();
267 let last = *offsets.last().unwrap();
268 let subslice = array
269 .values()
270 .sliced(first.to_usize(), last.to_usize() - first.to_usize());
271 set_variadic_buffer_counts(counts, &*subslice)
272 },
273 ArrowDataType::FixedSizeList(_, _) => {
274 let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
275 set_variadic_buffer_counts(counts, array.values().as_ref())
276 },
277 ArrowDataType::Dictionary(_, _, _) => (),
280 _ => (),
281 }
282}
283
284fn gc_bin_view<'a, T: ViewType + ?Sized>(
285 arr: &'a Box<dyn Array>,
286 concrete_arr: &'a BinaryViewArrayGeneric<T>,
287) -> Cow<'a, Box<dyn Array>> {
288 let bytes_len = concrete_arr.total_bytes_len();
289 let buffer_len = concrete_arr.total_buffer_len();
290 let extra_len = buffer_len.saturating_sub(bytes_len);
291 if extra_len < bytes_len.min(1024) {
292 Cow::Borrowed(arr)
294 } else {
295 Cow::Owned(concrete_arr.clone().gc().boxed())
297 }
298}
299
300pub fn encode_array(
301 array: &Box<dyn Array>,
302 options: &WriteOptions,
303 variadic_buffer_counts: &mut Vec<i64>,
304 buffers: &mut Vec<ipc::Buffer>,
305 arrow_data: &mut Vec<u8>,
306 nodes: &mut Vec<ipc::FieldNode>,
307 offset: &mut i64,
308) {
309 let array = match array.dtype() {
311 ArrowDataType::BinaryView => {
312 let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
313 gc_bin_view(array, concrete_arr)
314 },
315 ArrowDataType::Utf8View => {
316 let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
317 gc_bin_view(array, concrete_arr)
318 },
319 _ => Cow::Borrowed(array),
320 };
321 let array = array.as_ref().as_ref();
322
323 set_variadic_buffer_counts(variadic_buffer_counts, array);
324
325 write(
326 array,
327 buffers,
328 arrow_data,
329 nodes,
330 offset,
331 is_native_little_endian(),
332 options.compression,
333 )
334}
335
336pub fn encode_record_batch(
339 chunk: &RecordBatchT<Box<dyn Array>>,
340 options: &WriteOptions,
341 encoded_message: &mut EncodedData,
342) {
343 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
344 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
345 encoded_message.arrow_data.clear();
346
347 let mut offset = 0;
348 let mut variadic_buffer_counts = vec![];
349 for array in chunk.arrays() {
350 encode_array(
351 array,
352 options,
353 &mut variadic_buffer_counts,
354 &mut buffers,
355 &mut encoded_message.arrow_data,
356 &mut nodes,
357 &mut offset,
358 );
359 }
360
361 commit_encoded_arrays(
362 chunk.len(),
363 options,
364 variadic_buffer_counts,
365 buffers,
366 nodes,
367 encoded_message,
368 );
369}
370
371pub fn commit_encoded_arrays(
372 array_len: usize,
373 options: &WriteOptions,
374 variadic_buffer_counts: Vec<i64>,
375 buffers: Vec<ipc::Buffer>,
376 nodes: Vec<ipc::FieldNode>,
377 encoded_message: &mut EncodedData,
378) {
379 let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
380 None
381 } else {
382 Some(variadic_buffer_counts)
383 };
384
385 let compression = serialize_compression(options.compression);
386
387 let message = arrow_format::ipc::Message {
388 version: arrow_format::ipc::MetadataVersion::V5,
389 header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
390 arrow_format::ipc::RecordBatch {
391 length: array_len as i64,
392 nodes: Some(nodes),
393 buffers: Some(buffers),
394 compression,
395 variadic_buffer_counts,
396 },
397 ))),
398 body_length: encoded_message.arrow_data.len() as i64,
399 custom_metadata: None,
400 };
401
402 let mut builder = Builder::new();
403 let ipc_message = builder.finish(&message, None);
404 encoded_message.ipc_message = ipc_message.to_vec();
405}
406
407fn dictionary_batch_to_bytes<K: DictionaryKey>(
410 dict_id: i64,
411 array: &DictionaryArray<K>,
412 options: &WriteOptions,
413 is_little_endian: bool,
414) -> EncodedData {
415 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
416 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
417 let mut arrow_data: Vec<u8> = vec![];
418 let mut variadic_buffer_counts = vec![];
419 set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());
420
421 let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
422 None
423 } else {
424 Some(variadic_buffer_counts)
425 };
426
427 let length = write_dictionary(
428 array,
429 &mut buffers,
430 &mut arrow_data,
431 &mut nodes,
432 &mut 0,
433 is_little_endian,
434 options.compression,
435 false,
436 );
437
438 let compression = serialize_compression(options.compression);
439
440 let message = arrow_format::ipc::Message {
441 version: arrow_format::ipc::MetadataVersion::V5,
442 header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
443 arrow_format::ipc::DictionaryBatch {
444 id: dict_id,
445 data: Some(Box::new(arrow_format::ipc::RecordBatch {
446 length: length as i64,
447 nodes: Some(nodes),
448 buffers: Some(buffers),
449 compression,
450 variadic_buffer_counts,
451 })),
452 is_delta: false,
453 },
454 ))),
455 body_length: arrow_data.len() as i64,
456 custom_metadata: None,
457 };
458
459 let mut builder = Builder::new();
460 let ipc_message = builder.finish(&message, None);
461
462 EncodedData {
463 ipc_message: ipc_message.to_vec(),
464 arrow_data,
465 }
466}
467
468pub struct DictionaryTracker {
472 pub dictionaries: Dictionaries,
473 pub cannot_replace: bool,
474}
475
476impl DictionaryTracker {
477 pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
487 let values = match array.dtype() {
488 ArrowDataType::Dictionary(key_type, _, _) => {
489 match_integer_type!(key_type, |$T| {
490 let array = array
491 .as_any()
492 .downcast_ref::<DictionaryArray<$T>>()
493 .unwrap();
494 array.values()
495 })
496 },
497 _ => unreachable!(),
498 };
499
500 if let Some(last) = self.dictionaries.get(&dict_id) {
502 if last.as_ref() == values.as_ref() {
503 return Ok(false);
505 } else if self.cannot_replace {
506 polars_bail!(InvalidOperation:
507 "Dictionary replacement detected when writing IPC file format. \
508 Arrow IPC files only support a single dictionary for a given field \
509 across all batches."
510 );
511 }
512 };
513
514 self.dictionaries.insert(dict_id, values.clone());
515 Ok(true)
516 }
517}
518
/// The two byte sequences that make up one encoded IPC message.
#[derive(Debug, Default)]
pub struct EncodedData {
    /// Serialized message header (the flatbuffer-encoded `ipc::Message`).
    pub ipc_message: Vec<u8>,
    /// Message body holding the encoded array buffers.
    pub arrow_data: Vec<u8>,
}
527
/// Returns how many padding bytes must follow `len` bytes so that the total is a
/// multiple of 64.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    const ALIGNMENT: usize = 64;
    // Round up to the next multiple of the (power-of-two) alignment.
    let padded = (len + (ALIGNMENT - 1)) & !(ALIGNMENT - 1);
    padded - len
}
533
534#[derive(Debug, Clone, PartialEq)]
536pub struct Record<'a> {
537 columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
538 fields: Option<Cow<'a, [IpcField]>>,
539}
540
541impl Record<'_> {
542 pub fn fields(&self) -> Option<&[IpcField]> {
544 self.fields.as_deref()
545 }
546
547 pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
549 self.columns.borrow()
550 }
551}
552
553impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
554 fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
555 Self {
556 columns: Cow::Owned(columns),
557 fields: None,
558 }
559 }
560}
561
562impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
563where
564 F: Into<Cow<'a, [IpcField]>>,
565{
566 fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
567 Self {
568 columns: Cow::Owned(columns),
569 fields: fields.map(|f| f.into()),
570 }
571 }
572}
573
574impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
575where
576 F: Into<Cow<'a, [IpcField]>>,
577{
578 fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
579 Self {
580 columns: Cow::Borrowed(columns),
581 fields: fields.map(|f| f.into()),
582 }
583 }
584}