1use std::borrow::{Borrow, Cow};
2
3use arrow_format::ipc::planus::Builder;
4
5use crate::array::*;
6use crate::chunk::Chunk;
7use crate::datatypes::*;
8use crate::error::{Error, Result};
9use crate::io::ipc::endianess::is_native_little_endian;
10use crate::io::ipc::read::Dictionaries;
11
12use super::super::IpcField;
13use super::{write, write_dictionary};
14
/// Compression codec applied to the body buffers of IPC messages.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Compression {
    /// LZ4 frame format (serialized as `CompressionType::Lz4Frame` in the IPC metadata).
    LZ4,
    /// Zstandard (serialized as `CompressionType::Zstd` in the IPC metadata).
    ZSTD,
}
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
26pub struct WriteOptions {
27 pub compression: Option<Compression>,
30}
31
32fn encode_dictionary(
33 field: &IpcField,
34 array: &dyn Array,
35 options: &WriteOptions,
36 dictionary_tracker: &mut DictionaryTracker,
37 encoded_dictionaries: &mut Vec<EncodedData>,
38) -> Result<()> {
39 use PhysicalType::*;
40 match array.data_type().to_physical_type() {
41 Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
42 | FixedSizeBinary => Ok(()),
43 Dictionary(key_type) => match_integer_type!(key_type, |$T| {
44 let dict_id = field.dictionary_id
45 .ok_or_else(|| Error::InvalidArgumentError("Dictionaries must have an associated id".to_string()))?;
46
47 let emit = dictionary_tracker.insert(dict_id, array)?;
48
49 let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
50 let values = array.values();
51 encode_dictionary(field,
52 values.as_ref(),
53 options,
54 dictionary_tracker,
55 encoded_dictionaries
56 )?;
57
58 if emit {
59 encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
60 dict_id,
61 array,
62 options,
63 is_native_little_endian(),
64 ));
65 };
66 Ok(())
67 }),
68 Struct => {
69 let array = array.as_any().downcast_ref::<StructArray>().unwrap();
70 let fields = field.fields.as_slice();
71 if array.fields().len() != fields.len() {
72 return Err(Error::InvalidArgumentError(
73 "The number of fields in a struct must equal the number of children in IpcField".to_string(),
74 ));
75 }
76 fields
77 .iter()
78 .zip(array.values().iter())
79 .try_for_each(|(field, values)| {
80 encode_dictionary(
81 field,
82 values.as_ref(),
83 options,
84 dictionary_tracker,
85 encoded_dictionaries,
86 )
87 })
88 }
89 List => {
90 let values = array
91 .as_any()
92 .downcast_ref::<ListArray<i32>>()
93 .unwrap()
94 .values();
95 let field = &field.fields[0]; encode_dictionary(
97 field,
98 values.as_ref(),
99 options,
100 dictionary_tracker,
101 encoded_dictionaries,
102 )
103 }
104 LargeList => {
105 let values = array
106 .as_any()
107 .downcast_ref::<ListArray<i64>>()
108 .unwrap()
109 .values();
110 let field = &field.fields[0]; encode_dictionary(
112 field,
113 values.as_ref(),
114 options,
115 dictionary_tracker,
116 encoded_dictionaries,
117 )
118 }
119 FixedSizeList => {
120 let values = array
121 .as_any()
122 .downcast_ref::<FixedSizeListArray>()
123 .unwrap()
124 .values();
125 let field = &field.fields[0]; encode_dictionary(
127 field,
128 values.as_ref(),
129 options,
130 dictionary_tracker,
131 encoded_dictionaries,
132 )
133 }
134 Union => {
135 let values = array
136 .as_any()
137 .downcast_ref::<UnionArray>()
138 .unwrap()
139 .fields();
140 let fields = &field.fields[..]; if values.len() != fields.len() {
142 return Err(Error::InvalidArgumentError(
143 "The number of fields in a union must equal the number of children in IpcField"
144 .to_string(),
145 ));
146 }
147 fields
148 .iter()
149 .zip(values.iter())
150 .try_for_each(|(field, values)| {
151 encode_dictionary(
152 field,
153 values.as_ref(),
154 options,
155 dictionary_tracker,
156 encoded_dictionaries,
157 )
158 })
159 }
160 Map => {
161 let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
162 let field = &field.fields[0]; encode_dictionary(
164 field,
165 values.as_ref(),
166 options,
167 dictionary_tracker,
168 encoded_dictionaries,
169 )
170 }
171 }
172}
173
174pub fn encode_chunk(
175 chunk: &Chunk<Box<dyn Array>>,
176 fields: &[IpcField],
177 dictionary_tracker: &mut DictionaryTracker,
178 options: &WriteOptions,
179) -> Result<(Vec<EncodedData>, EncodedData)> {
180 let mut encoded_message = EncodedData::default();
181 let encoded_dictionaries = encode_chunk_amortized(
182 chunk,
183 fields,
184 dictionary_tracker,
185 options,
186 &mut encoded_message,
187 )?;
188 Ok((encoded_dictionaries, encoded_message))
189}
190
191pub fn encode_chunk_amortized(
193 chunk: &Chunk<Box<dyn Array>>,
194 fields: &[IpcField],
195 dictionary_tracker: &mut DictionaryTracker,
196 options: &WriteOptions,
197 encoded_message: &mut EncodedData,
198) -> Result<Vec<EncodedData>> {
199 let mut encoded_dictionaries = vec![];
200
201 for (field, array) in fields.iter().zip(chunk.as_ref()) {
202 encode_dictionary(
203 field,
204 array.as_ref(),
205 options,
206 dictionary_tracker,
207 &mut encoded_dictionaries,
208 )?;
209 }
210
211 chunk_to_bytes_amortized(chunk, options, encoded_message);
212
213 Ok(encoded_dictionaries)
214}
215
216fn serialize_compression(
217 compression: Option<Compression>,
218) -> Option<Box<arrow_format::ipc::BodyCompression>> {
219 if let Some(compression) = compression {
220 let codec = match compression {
221 Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
222 Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
223 };
224 Some(Box::new(arrow_format::ipc::BodyCompression {
225 codec,
226 method: arrow_format::ipc::BodyCompressionMethod::Buffer,
227 }))
228 } else {
229 None
230 }
231}
232
233fn chunk_to_bytes_amortized(
236 chunk: &Chunk<Box<dyn Array>>,
237 options: &WriteOptions,
238 encoded_message: &mut EncodedData,
239) {
240 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
241 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
242 let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
243 arrow_data.clear();
244
245 let mut offset = 0;
246 for array in chunk.arrays() {
247 write(
248 array.as_ref(),
249 &mut buffers,
250 &mut arrow_data,
251 &mut nodes,
252 &mut offset,
253 is_native_little_endian(),
254 options.compression,
255 )
256 }
257
258 let compression = serialize_compression(options.compression);
259
260 let message = arrow_format::ipc::Message {
261 version: arrow_format::ipc::MetadataVersion::V5,
262 header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
263 arrow_format::ipc::RecordBatch {
264 length: chunk.len() as i64,
265 nodes: Some(nodes),
266 buffers: Some(buffers),
267 compression,
268 },
269 ))),
270 body_length: arrow_data.len() as i64,
271 custom_metadata: None,
272 };
273
274 let mut builder = Builder::new();
275 let ipc_message = builder.finish(&message, None);
276 encoded_message.ipc_message = ipc_message.to_vec();
277 encoded_message.arrow_data = arrow_data
278}
279
280fn dictionary_batch_to_bytes<K: DictionaryKey>(
283 dict_id: i64,
284 array: &DictionaryArray<K>,
285 options: &WriteOptions,
286 is_little_endian: bool,
287) -> EncodedData {
288 let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
289 let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
290 let mut arrow_data: Vec<u8> = vec![];
291
292 let length = write_dictionary(
293 array,
294 &mut buffers,
295 &mut arrow_data,
296 &mut nodes,
297 &mut 0,
298 is_little_endian,
299 options.compression,
300 false,
301 );
302
303 let compression = serialize_compression(options.compression);
304
305 let message = arrow_format::ipc::Message {
306 version: arrow_format::ipc::MetadataVersion::V5,
307 header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
308 arrow_format::ipc::DictionaryBatch {
309 id: dict_id,
310 data: Some(Box::new(arrow_format::ipc::RecordBatch {
311 length: length as i64,
312 nodes: Some(nodes),
313 buffers: Some(buffers),
314 compression,
315 })),
316 is_delta: false,
317 },
318 ))),
319 body_length: arrow_data.len() as i64,
320 custom_metadata: None,
321 };
322
323 let mut builder = Builder::new();
324 let ipc_message = builder.finish(&message, None);
325
326 EncodedData {
327 ipc_message: ipc_message.to_vec(),
328 arrow_data,
329 }
330}
331
332pub struct DictionaryTracker {
336 pub dictionaries: Dictionaries,
337 pub cannot_replace: bool,
338}
339
340impl DictionaryTracker {
341 pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> Result<bool> {
351 let values = match array.data_type() {
352 DataType::Dictionary(key_type, _, _) => {
353 match_integer_type!(key_type, |$T| {
354 let array = array
355 .as_any()
356 .downcast_ref::<DictionaryArray<$T>>()
357 .unwrap();
358 array.values()
359 })
360 }
361 _ => unreachable!(),
362 };
363
364 if let Some(last) = self.dictionaries.get(&dict_id) {
366 if last.as_ref() == values.as_ref() {
367 return Ok(false);
369 } else if self.cannot_replace {
370 return Err(Error::InvalidArgumentError(
371 "Dictionary replacement detected when writing IPC file format. \
372 Arrow IPC files only support a single dictionary for a given field \
373 across all batches."
374 .to_string(),
375 ));
376 }
377 };
378
379 self.dictionaries.insert(dict_id, values.clone());
380 Ok(true)
381 }
382}
383
/// A serialized IPC message: the flatbuffer-encoded header plus the message body.
#[derive(Default, Debug)]
pub struct EncodedData {
    /// The flatbuffer-serialized `Message` header.
    pub ipc_message: Vec<u8>,
    /// The message body, holding the serialized array buffers.
    pub arrow_data: Vec<u8>,
}
392
/// Returns the number of padding bytes needed to grow a buffer of `len` bytes to the next
/// multiple of 64 (zero when `len` is already a multiple of 64).
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    let remainder = len % 64;
    if remainder == 0 {
        0
    } else {
        64 - remainder
    }
}
398
399#[derive(Debug, Clone, PartialEq)]
401pub struct Record<'a> {
402 columns: Cow<'a, Chunk<Box<dyn Array>>>,
403 fields: Option<Cow<'a, [IpcField]>>,
404}
405
406impl<'a> Record<'a> {
407 pub fn fields(&self) -> Option<&[IpcField]> {
409 self.fields.as_deref()
410 }
411
412 pub fn columns(&self) -> &Chunk<Box<dyn Array>> {
414 self.columns.borrow()
415 }
416}
417
418impl From<Chunk<Box<dyn Array>>> for Record<'static> {
419 fn from(columns: Chunk<Box<dyn Array>>) -> Self {
420 Self {
421 columns: Cow::Owned(columns),
422 fields: None,
423 }
424 }
425}
426
427impl<'a, F> From<(Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
428where
429 F: Into<Cow<'a, [IpcField]>>,
430{
431 fn from((columns, fields): (Chunk<Box<dyn Array>>, Option<F>)) -> Self {
432 Self {
433 columns: Cow::Owned(columns),
434 fields: fields.map(|f| f.into()),
435 }
436 }
437}
438
439impl<'a, F> From<(&'a Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
440where
441 F: Into<Cow<'a, [IpcField]>>,
442{
443 fn from((columns, fields): (&'a Chunk<Box<dyn Array>>, Option<F>)) -> Self {
444 Self {
445 columns: Cow::Borrowed(columns),
446 fields: fields.map(|f| f.into()),
447 }
448 }
449}