Skip to main content

polars_arrow/io/avro/write/
serialize.rs

1use avro_schema::schema::{Record, Schema as AvroSchema};
2use avro_schema::write::encode;
3
4use super::super::super::iterator::*;
5use crate::array::*;
6use crate::bitmap::utils::ZipValidity;
7use crate::datatypes::{ArrowDataType, IntervalUnit, PhysicalType, PrimitiveType};
8use crate::offset::Offset;
9use crate::types::months_days_ns;
10
// A nullable column is written as the avro union `["null", T]`, so every value is prefixed by
// the zigzag-encoded index of the union branch it uses: branch 0 (`null`) encodes to 0 and
// branch 1 (the value is present) encodes to 2.
const IS_NULL: u8 = 0x00;
const IS_VALID: u8 = 0x02;
14
15/// A type alias for a boxed [`StreamingIterator`], used to write arrays into avro rows
16/// (i.e. a column -> row transposition of types known at run-time)
17pub type BoxSerializer<'a> = Box<dyn StreamingIterator<Item = [u8]> + 'a + Send + Sync>;
18
19fn utf8_required<O: Offset>(array: &Utf8Array<O>) -> BoxSerializer<'_> {
20    Box::new(BufStreamingIterator::new(
21        array.values_iter(),
22        |x, buf| {
23            encode::zigzag_encode(x.len() as i64, buf).unwrap();
24            buf.extend_from_slice(x.as_bytes());
25        },
26        vec![],
27    ))
28}
29
30fn utf8_optional<O: Offset>(array: &Utf8Array<O>) -> BoxSerializer<'_> {
31    Box::new(BufStreamingIterator::new(
32        array.iter(),
33        |x, buf| {
34            if let Some(x) = x {
35                buf.push(IS_VALID);
36                encode::zigzag_encode(x.len() as i64, buf).unwrap();
37                buf.extend_from_slice(x.as_bytes());
38            } else {
39                buf.push(IS_NULL);
40            }
41        },
42        vec![],
43    ))
44}
45
46fn binary_required<O: Offset>(array: &BinaryArray<O>) -> BoxSerializer<'_> {
47    Box::new(BufStreamingIterator::new(
48        array.values_iter(),
49        |x, buf| {
50            encode::zigzag_encode(x.len() as i64, buf).unwrap();
51            buf.extend_from_slice(x);
52        },
53        vec![],
54    ))
55}
56
57fn binary_optional<O: Offset>(array: &BinaryArray<O>) -> BoxSerializer<'_> {
58    Box::new(BufStreamingIterator::new(
59        array.iter(),
60        |x, buf| {
61            if let Some(x) = x {
62                buf.push(IS_VALID);
63                encode::zigzag_encode(x.len() as i64, buf).unwrap();
64                buf.extend_from_slice(x);
65            } else {
66                buf.push(IS_NULL);
67            }
68        },
69        vec![],
70    ))
71}
72
73fn fixed_size_binary_required(array: &FixedSizeBinaryArray) -> BoxSerializer<'_> {
74    Box::new(BufStreamingIterator::new(
75        array.values_iter(),
76        |x, buf| {
77            buf.extend_from_slice(x);
78        },
79        vec![],
80    ))
81}
82
83fn fixed_size_binary_optional(array: &FixedSizeBinaryArray) -> BoxSerializer<'_> {
84    Box::new(BufStreamingIterator::new(
85        array.iter(),
86        |x, buf| {
87            if let Some(x) = x {
88                buf.push(IS_VALID);
89                buf.extend_from_slice(x);
90            } else {
91                buf.push(IS_NULL);
92            }
93        },
94        vec![],
95    ))
96}
97
98fn list_required<'a, O: Offset>(array: &'a ListArray<O>, schema: &AvroSchema) -> BoxSerializer<'a> {
99    let mut inner = new_serializer(array.values().as_ref(), schema);
100    let lengths = array
101        .offsets()
102        .buffer()
103        .windows(2)
104        .map(|w| (w[1] - w[0]).to_usize() as i64);
105
106    Box::new(BufStreamingIterator::new(
107        lengths,
108        move |length, buf| {
109            encode::zigzag_encode(length, buf).unwrap();
110            let mut rows = 0;
111            while let Some(item) = inner.next() {
112                buf.extend_from_slice(item);
113                rows += 1;
114                if rows == length {
115                    encode::zigzag_encode(0, buf).unwrap();
116                    break;
117                }
118            }
119        },
120        vec![],
121    ))
122}
123
124fn list_optional<'a, O: Offset>(array: &'a ListArray<O>, schema: &AvroSchema) -> BoxSerializer<'a> {
125    let mut inner = new_serializer(array.values().as_ref(), schema);
126    let lengths = array
127        .offsets()
128        .buffer()
129        .windows(2)
130        .map(|w| (w[1] - w[0]).to_usize() as i64);
131    let lengths = ZipValidity::new_with_validity(lengths, array.validity());
132
133    Box::new(BufStreamingIterator::new(
134        lengths,
135        move |length, buf| {
136            if let Some(length) = length {
137                buf.push(IS_VALID);
138                encode::zigzag_encode(length, buf).unwrap();
139                let mut rows = 0;
140                while let Some(item) = inner.next() {
141                    buf.extend_from_slice(item);
142                    rows += 1;
143                    if rows == length {
144                        encode::zigzag_encode(0, buf).unwrap();
145                        break;
146                    }
147                }
148            } else {
149                buf.push(IS_NULL);
150            }
151        },
152        vec![],
153    ))
154}
155
156fn struct_required<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> {
157    let schemas = schema.fields.iter().map(|x| &x.schema);
158    let mut inner = array
159        .values()
160        .iter()
161        .zip(schemas)
162        .map(|(x, schema)| new_serializer(x.as_ref(), schema))
163        .collect::<Vec<_>>();
164
165    Box::new(BufStreamingIterator::new(
166        0..array.len(),
167        move |_, buf| {
168            inner
169                .iter_mut()
170                .for_each(|item| buf.extend_from_slice(item.next().unwrap()))
171        },
172        vec![],
173    ))
174}
175
176fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> {
177    let schemas = schema.fields.iter().map(|x| &x.schema);
178    let mut inner = array
179        .values()
180        .iter()
181        .zip(schemas)
182        .map(|(x, schema)| new_serializer(x.as_ref(), schema))
183        .collect::<Vec<_>>();
184
185    let iterator = ZipValidity::new_with_validity(0..array.len(), array.validity());
186
187    Box::new(BufStreamingIterator::new(
188        iterator,
189        move |maybe, buf| {
190            if maybe.is_some() {
191                buf.push(IS_VALID);
192                inner
193                    .iter_mut()
194                    .for_each(|item| buf.extend_from_slice(item.next().unwrap()))
195            } else {
196                buf.push(IS_NULL);
197                // skip the item
198                inner.iter_mut().for_each(|item| {
199                    let _ = item.next().unwrap();
200                });
201            }
202        },
203        vec![],
204    ))
205}
206
207/// Creates a [`StreamingIterator`] trait object that presents items from `array`
208/// encoded according to `schema`.
209/// # Panic
210/// This function panics iff the `dtype` is not supported (use [`can_serialize`] to check)
211/// # Implementation
212/// This function performs minimal CPU work: it dynamically dispatches based on the schema
213/// and arrow type.
214pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSerializer<'a> {
215    let dtype = array.dtype().to_physical_type();
216
217    match (dtype, schema) {
218        (PhysicalType::Boolean, AvroSchema::Boolean) => {
219            let values = array.as_any().downcast_ref::<BooleanArray>().unwrap();
220            Box::new(BufStreamingIterator::new(
221                values.values_iter(),
222                |x, buf| {
223                    buf.push(x as u8);
224                },
225                vec![],
226            ))
227        },
228        (PhysicalType::Boolean, AvroSchema::Union(_)) => {
229            let values = array.as_any().downcast_ref::<BooleanArray>().unwrap();
230            Box::new(BufStreamingIterator::new(
231                values.iter(),
232                |x, buf| {
233                    if let Some(x) = x {
234                        buf.extend_from_slice(&[IS_VALID, x as u8]);
235                    } else {
236                        buf.push(IS_NULL);
237                    }
238                },
239                vec![],
240            ))
241        },
242        (PhysicalType::Utf8, AvroSchema::Union(_)) => {
243            utf8_optional::<i32>(array.as_any().downcast_ref().unwrap())
244        },
245        (PhysicalType::LargeUtf8, AvroSchema::Union(_)) => {
246            utf8_optional::<i64>(array.as_any().downcast_ref().unwrap())
247        },
248        (PhysicalType::Utf8, AvroSchema::String(_)) => {
249            utf8_required::<i32>(array.as_any().downcast_ref().unwrap())
250        },
251        (PhysicalType::LargeUtf8, AvroSchema::String(_)) => {
252            utf8_required::<i64>(array.as_any().downcast_ref().unwrap())
253        },
254        (PhysicalType::Binary, AvroSchema::Union(_)) => {
255            binary_optional::<i32>(array.as_any().downcast_ref().unwrap())
256        },
257        (PhysicalType::LargeBinary, AvroSchema::Union(_)) => {
258            binary_optional::<i64>(array.as_any().downcast_ref().unwrap())
259        },
260        (PhysicalType::FixedSizeBinary, AvroSchema::Union(_)) => {
261            fixed_size_binary_optional(array.as_any().downcast_ref().unwrap())
262        },
263        (PhysicalType::Binary, AvroSchema::Bytes(_)) => {
264            binary_required::<i32>(array.as_any().downcast_ref().unwrap())
265        },
266        (PhysicalType::LargeBinary, AvroSchema::Bytes(_)) => {
267            binary_required::<i64>(array.as_any().downcast_ref().unwrap())
268        },
269        (PhysicalType::FixedSizeBinary, AvroSchema::Fixed(_)) => {
270            fixed_size_binary_required(array.as_any().downcast_ref().unwrap())
271        },
272
273        (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Union(_)) => {
274            let values = array
275                .as_any()
276                .downcast_ref::<PrimitiveArray<i32>>()
277                .unwrap();
278            Box::new(BufStreamingIterator::new(
279                values.iter(),
280                |x, buf| {
281                    if let Some(x) = x {
282                        buf.push(IS_VALID);
283                        encode::zigzag_encode(*x as i64, buf).unwrap();
284                    } else {
285                        buf.push(IS_NULL);
286                    }
287                },
288                vec![],
289            ))
290        },
291        (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Int(_)) => {
292            let values = array
293                .as_any()
294                .downcast_ref::<PrimitiveArray<i32>>()
295                .unwrap();
296            Box::new(BufStreamingIterator::new(
297                values.values().iter(),
298                |x, buf| {
299                    encode::zigzag_encode(*x as i64, buf).unwrap();
300                },
301                vec![],
302            ))
303        },
304        (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Union(_)) => {
305            let values = array
306                .as_any()
307                .downcast_ref::<PrimitiveArray<i64>>()
308                .unwrap();
309            Box::new(BufStreamingIterator::new(
310                values.iter(),
311                |x, buf| {
312                    if let Some(x) = x {
313                        buf.push(IS_VALID);
314                        encode::zigzag_encode(*x, buf).unwrap();
315                    } else {
316                        buf.push(IS_NULL);
317                    }
318                },
319                vec![],
320            ))
321        },
322        (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Long(_)) => {
323            let values = array
324                .as_any()
325                .downcast_ref::<PrimitiveArray<i64>>()
326                .unwrap();
327            Box::new(BufStreamingIterator::new(
328                values.values().iter(),
329                |x, buf| {
330                    encode::zigzag_encode(*x, buf).unwrap();
331                },
332                vec![],
333            ))
334        },
335        (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Union(_)) => {
336            let values = array
337                .as_any()
338                .downcast_ref::<PrimitiveArray<f32>>()
339                .unwrap();
340            Box::new(BufStreamingIterator::new(
341                values.iter(),
342                |x, buf| {
343                    if let Some(x) = x {
344                        buf.push(IS_VALID);
345                        buf.extend(x.to_le_bytes())
346                    } else {
347                        buf.push(IS_NULL);
348                    }
349                },
350                vec![],
351            ))
352        },
353        (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Float) => {
354            let values = array
355                .as_any()
356                .downcast_ref::<PrimitiveArray<f32>>()
357                .unwrap();
358            Box::new(BufStreamingIterator::new(
359                values.values().iter(),
360                |x, buf| {
361                    buf.extend_from_slice(&x.to_le_bytes());
362                },
363                vec![],
364            ))
365        },
366        (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Union(_)) => {
367            let values = array
368                .as_any()
369                .downcast_ref::<PrimitiveArray<f64>>()
370                .unwrap();
371            Box::new(BufStreamingIterator::new(
372                values.iter(),
373                |x, buf| {
374                    if let Some(x) = x {
375                        buf.push(IS_VALID);
376                        buf.extend(x.to_le_bytes())
377                    } else {
378                        buf.push(IS_NULL);
379                    }
380                },
381                vec![],
382            ))
383        },
384        (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Double) => {
385            let values = array
386                .as_any()
387                .downcast_ref::<PrimitiveArray<f64>>()
388                .unwrap();
389            Box::new(BufStreamingIterator::new(
390                values.values().iter(),
391                |x, buf| {
392                    buf.extend_from_slice(&x.to_le_bytes());
393                },
394                vec![],
395            ))
396        },
397        (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Bytes(_)) => {
398            let values = array
399                .as_any()
400                .downcast_ref::<PrimitiveArray<i128>>()
401                .unwrap();
402            Box::new(BufStreamingIterator::new(
403                values.values().iter(),
404                |x, buf| {
405                    let len = ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize;
406                    encode::zigzag_encode((16 - len) as i64, buf).unwrap();
407                    buf.extend_from_slice(&x.to_be_bytes()[len..]);
408                },
409                vec![],
410            ))
411        },
412        (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Union(_)) => {
413            let values = array
414                .as_any()
415                .downcast_ref::<PrimitiveArray<i128>>()
416                .unwrap();
417            Box::new(BufStreamingIterator::new(
418                values.iter(),
419                |x, buf| {
420                    if let Some(x) = x {
421                        buf.push(IS_VALID);
422                        let len =
423                            ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize;
424                        encode::zigzag_encode((16 - len) as i64, buf).unwrap();
425                        buf.extend_from_slice(&x.to_be_bytes()[len..]);
426                    } else {
427                        buf.push(IS_NULL);
428                    }
429                },
430                vec![],
431            ))
432        },
433        (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Fixed(_)) => {
434            let values = array
435                .as_any()
436                .downcast_ref::<PrimitiveArray<months_days_ns>>()
437                .unwrap();
438            Box::new(BufStreamingIterator::new(
439                values.values().iter(),
440                interval_write,
441                vec![],
442            ))
443        },
444        (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Union(_)) => {
445            let values = array
446                .as_any()
447                .downcast_ref::<PrimitiveArray<months_days_ns>>()
448                .unwrap();
449            Box::new(BufStreamingIterator::new(
450                values.iter(),
451                |x, buf| {
452                    if let Some(x) = x {
453                        buf.push(IS_VALID);
454                        interval_write(x, buf)
455                    } else {
456                        buf.push(IS_NULL);
457                    }
458                },
459                vec![],
460            ))
461        },
462
463        (PhysicalType::List, AvroSchema::Array(schema)) => {
464            list_required::<i32>(array.as_any().downcast_ref().unwrap(), schema.as_ref())
465        },
466        (PhysicalType::LargeList, AvroSchema::Array(schema)) => {
467            list_required::<i64>(array.as_any().downcast_ref().unwrap(), schema.as_ref())
468        },
469        (PhysicalType::List, AvroSchema::Union(inner)) => {
470            let schema = if let AvroSchema::Array(schema) = &inner[1] {
471                schema.as_ref()
472            } else {
473                unreachable!("The schema declaration does not match the deserialization")
474            };
475            list_optional::<i32>(array.as_any().downcast_ref().unwrap(), schema)
476        },
477        (PhysicalType::LargeList, AvroSchema::Union(inner)) => {
478            let schema = if let AvroSchema::Array(schema) = &inner[1] {
479                schema.as_ref()
480            } else {
481                unreachable!("The schema declaration does not match the deserialization")
482            };
483            list_optional::<i64>(array.as_any().downcast_ref().unwrap(), schema)
484        },
485        (PhysicalType::Struct, AvroSchema::Record(inner)) => {
486            struct_required(array.as_any().downcast_ref().unwrap(), inner)
487        },
488        (PhysicalType::Struct, AvroSchema::Union(inner)) => {
489            let inner = if let AvroSchema::Record(inner) = &inner[1] {
490                inner
491            } else {
492                unreachable!("The schema declaration does not match the deserialization")
493            };
494            struct_optional(array.as_any().downcast_ref().unwrap(), inner)
495        },
496        (a, b) => todo!("{:?} -> {:?} not supported", a, b),
497    }
498}
499
500/// Whether [`new_serializer`] supports `dtype`.
501pub fn can_serialize(dtype: &ArrowDataType) -> bool {
502    use ArrowDataType::*;
503    match dtype.to_storage() {
504        List(inner) => return can_serialize(&inner.dtype),
505        LargeList(inner) => return can_serialize(&inner.dtype),
506        Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.dtype)),
507        _ => {},
508    };
509
510    matches!(
511        dtype,
512        Boolean
513            | Int32
514            | Int64
515            | Float32
516            | Float64
517            | Decimal(_, _)
518            | Utf8
519            | Binary
520            | FixedSizeBinary(_)
521            | LargeUtf8
522            | LargeBinary
523            | Interval(IntervalUnit::MonthDayNano)
524    )
525}
526
527#[inline]
528fn interval_write(x: &months_days_ns, buf: &mut Vec<u8>) {
529    // https://avro.apache.org/docs/current/spec.html#Duration
530    // 12 bytes, months, days, millis in LE
531    buf.reserve(12);
532    buf.extend(x.months().to_le_bytes());
533    buf.extend(x.days().to_le_bytes());
534    buf.extend(((x.ns() / 1_000_000) as i32).to_le_bytes());
535}