Skip to main content

polars_row/
decode.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::bitmap::{Bitmap, BitmapBuilder};
3use arrow::datatypes::ArrowDataType;
4use arrow::offset::OffsetsBuffer;
5use arrow::types::NativeType;
6use polars_buffer::Buffer;
7use polars_dtype::categorical::CatNative;
8
9use self::encode::fixed_size;
10use self::row::{RowEncodingCategoricalContext, RowEncodingOptions};
11use self::variable::utf8::decode_str;
12use super::*;
13use crate::fixed::numeric::{FixedLengthEncoding, FromSlice};
14use crate::fixed::{boolean, decimal, numeric};
15use crate::variable::{binary, no_order, utf8};
16
17/// Decode `rows` into a arrow format
18/// # Safety
19/// This will not do any bound checks. Caller must ensure the `rows` are valid
20/// encodings.
21pub unsafe fn decode_rows_from_binary<'a>(
22    arr: &'a BinaryArray<i64>,
23    opts: &[RowEncodingOptions],
24    dicts: &[Option<RowEncodingContext>],
25    dtypes: &[ArrowDataType],
26    rows: &mut Vec<&'a [u8]>,
27) -> Vec<ArrayRef> {
28    assert_eq!(arr.null_count(), 0);
29    rows.clear();
30    rows.extend(arr.values_iter());
31    decode_rows(rows, opts, dicts, dtypes)
32}
33
34/// Decode `rows` into a arrow format
35/// # Safety
36/// This will not do any bound checks. Caller must ensure the `rows` are valid
37/// encodings.
38pub unsafe fn decode_rows(
39    // the rows will be updated while the data is decoded
40    rows: &mut [&[u8]],
41    opts: &[RowEncodingOptions],
42    dicts: &[Option<RowEncodingContext>],
43    dtypes: &[ArrowDataType],
44) -> Vec<ArrayRef> {
45    assert_eq!(opts.len(), dtypes.len());
46    assert_eq!(dicts.len(), dtypes.len());
47
48    dtypes
49        .iter()
50        .zip(opts)
51        .zip(dicts)
52        .map(|((dtype, opt), dict)| decode(rows, *opt, dict.as_ref(), dtype))
53        .collect()
54}
55
56unsafe fn decode_validity(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Option<Bitmap> {
57    // 2 loop system to avoid the overhead of allocating the bitmap if all the elements are valid.
58
59    let null_sentinel = opt.null_sentinel();
60    let first_null = (0..rows.len()).find(|&i| {
61        let v;
62        (v, rows[i]) = rows[i].split_at_unchecked(1);
63        v[0] == null_sentinel
64    });
65
66    // No nulls just return None
67    let first_null = first_null?;
68
69    let mut bm = BitmapBuilder::new();
70    bm.reserve(rows.len());
71    bm.extend_constant(first_null, true);
72    bm.push(false);
73    bm.extend_trusted_len_iter(rows[first_null + 1..].iter_mut().map(|row| {
74        let v;
75        (v, *row) = row.split_at_unchecked(1);
76        v[0] != null_sentinel
77    }));
78    bm.into_opt_validity()
79}
80
81// We inline this in an attempt to avoid the dispatch cost.
82#[inline(always)]
83fn dtype_and_data_to_encoded_item_len(
84    dtype: &ArrowDataType,
85    data: &[u8],
86    opt: RowEncodingOptions,
87    dict: Option<&RowEncodingContext>,
88) -> usize {
89    // Fast path: if the size is fixed, we can just divide.
90    if let Some(size) = fixed_size(dtype, opt, dict) {
91        return size;
92    }
93
94    use ArrowDataType as D;
95    match dtype {
96        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
97            if opt.contains(RowEncodingOptions::NO_ORDER) =>
98        unsafe { no_order::len_from_buffer(data, opt) },
99        D::Binary | D::LargeBinary | D::BinaryView => unsafe {
100            binary::encoded_item_len(data, opt)
101        },
102        D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { utf8::len_from_buffer(data, opt) },
103
104        D::List(list_field) | D::LargeList(list_field) => {
105            let mut data = data;
106            let mut item_len = 0;
107
108            let list_continuation_token = opt.list_continuation_token();
109
110            while data[0] == list_continuation_token {
111                data = &data[1..];
112                let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), data, opt, dict);
113                data = &data[len..];
114                item_len += 1 + len;
115            }
116            1 + item_len
117        },
118
119        D::FixedSizeBinary(_) => todo!(),
120        D::FixedSizeList(fsl_field, width) => {
121            let mut data = &data[1..];
122            let mut item_len = 1; // validity byte
123
124            for _ in 0..*width {
125                let len = dtype_and_data_to_encoded_item_len(
126                    fsl_field.dtype(),
127                    data,
128                    opt.into_nested(),
129                    dict,
130                );
131                data = &data[len..];
132                item_len += len;
133            }
134            item_len
135        },
136        D::Struct(struct_fields) => {
137            let mut data = &data[1..];
138            let mut item_len = 1; // validity byte
139
140            for struct_field in struct_fields {
141                let len = dtype_and_data_to_encoded_item_len(
142                    struct_field.dtype(),
143                    data,
144                    opt.into_nested(),
145                    dict,
146                );
147                data = &data[len..];
148                item_len += len;
149            }
150            item_len
151        },
152
153        D::Union(_) => todo!(),
154        D::Map(_, _) => todo!(),
155        D::Decimal32(_, _) => todo!(),
156        D::Decimal64(_, _) => todo!(),
157        D::Decimal256(_, _) => todo!(),
158        D::Extension(_) => todo!(),
159        D::Unknown => todo!(),
160
161        _ => unreachable!(),
162    }
163}
164
165fn rows_for_fixed_size_list<'a>(
166    dtype: &ArrowDataType,
167    opt: RowEncodingOptions,
168    dict: Option<&RowEncodingContext>,
169    width: usize,
170    rows: &mut [&'a [u8]],
171    nested_rows: &mut Vec<&'a [u8]>,
172) {
173    nested_rows.clear();
174    nested_rows.reserve(rows.len() * width);
175
176    // Fast path: if the size is fixed, we can just divide.
177    if let Some(size) = fixed_size(dtype, opt, dict) {
178        for row in rows.iter_mut() {
179            for i in 0..width {
180                nested_rows.push(&row[(i * size)..][..size]);
181            }
182            *row = &row[size * width..];
183        }
184        return;
185    }
186
187    // @TODO: This is quite slow since we need to dispatch for possibly every nested type
188    for row in rows.iter_mut() {
189        for _ in 0..width {
190            let length = dtype_and_data_to_encoded_item_len(dtype, row, opt.into_nested(), dict);
191            let v;
192            (v, *row) = row.split_at(length);
193            nested_rows.push(v);
194        }
195    }
196}
197
198unsafe fn decode_cat<T: NativeType + FixedLengthEncoding + CatNative>(
199    rows: &mut [&[u8]],
200    opt: RowEncodingOptions,
201    ctx: &RowEncodingCategoricalContext,
202) -> PrimitiveArray<T>
203where
204    T::Encoded: FromSlice,
205{
206    if ctx.is_enum || !opt.is_ordered() {
207        numeric::decode_primitive::<T>(rows, opt)
208    } else {
209        variable::utf8::decode_str_as_cat::<T>(rows, opt, &ctx.mapping)
210    }
211}
212
213unsafe fn decode(
214    rows: &mut [&[u8]],
215    opt: RowEncodingOptions,
216    dict: Option<&RowEncodingContext>,
217    dtype: &ArrowDataType,
218) -> ArrayRef {
219    use ArrowDataType as D;
220
221    if let Some(RowEncodingContext::Categorical(ctx)) = dict {
222        match dtype {
223            D::UInt8 => return decode_cat::<u8>(rows, opt, ctx).to_boxed(),
224            D::UInt16 => return decode_cat::<u16>(rows, opt, ctx).to_boxed(),
225            D::UInt32 => return decode_cat::<u32>(rows, opt, ctx).to_boxed(),
226            D::FixedSizeList(..) | D::List(_) | D::LargeList(_) => {
227                // Nested type, handled below.
228            },
229            _ => unreachable!(),
230        };
231    }
232
233    match dtype {
234        D::Null => NullArray::new(D::Null, rows.len()).to_boxed(),
235        D::Boolean => boolean::decode_bool(rows, opt).to_boxed(),
236        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
237            if opt.contains(RowEncodingOptions::NO_ORDER) =>
238        {
239            let array = no_order::decode_variable_no_order(rows, opt);
240
241            if matches!(dtype, D::Utf8 | D::LargeUtf8 | D::Utf8View) {
242                unsafe { array.to_utf8view_unchecked() }.to_boxed()
243            } else {
244                array.to_boxed()
245            }
246        },
247        D::Binary | D::LargeBinary | D::BinaryView => binary::decode_binview(rows, opt).to_boxed(),
248        D::Utf8 | D::LargeUtf8 | D::Utf8View => decode_str(rows, opt).boxed(),
249
250        D::Struct(fields) => {
251            let validity = decode_validity(rows, opt);
252
253            let values = match dict {
254                None => fields
255                    .iter()
256                    .map(|struct_fld| decode(rows, opt.into_nested(), None, struct_fld.dtype()))
257                    .collect(),
258                Some(RowEncodingContext::Struct(dicts)) => fields
259                    .iter()
260                    .zip(dicts)
261                    .map(|(struct_fld, dict)| {
262                        decode(rows, opt.into_nested(), dict.as_ref(), struct_fld.dtype())
263                    })
264                    .collect(),
265                _ => unreachable!(),
266            };
267            StructArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
268        },
269        D::FixedSizeList(fsl_field, width) => {
270            let validity = decode_validity(rows, opt);
271
272            // @TODO: we could consider making this into a scratchpad
273            let mut nested_rows = Vec::new();
274            rows_for_fixed_size_list(
275                fsl_field.dtype(),
276                opt.into_nested(),
277                dict,
278                *width,
279                rows,
280                &mut nested_rows,
281            );
282
283            let values = decode(&mut nested_rows, opt.into_nested(), dict, fsl_field.dtype());
284
285            FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
286        },
287        D::List(list_field) | D::LargeList(list_field) => {
288            let mut validity = BitmapBuilder::new();
289
290            // @TODO: we could consider making this into a scratchpad
291            let num_rows = rows.len();
292            let mut nested_rows = Vec::new();
293            let mut offsets = Vec::with_capacity(rows.len() + 1);
294            offsets.push(0);
295
296            let list_null_sentinel = opt.list_null_sentinel();
297            let list_continuation_token = opt.list_continuation_token();
298            let list_termination_token = opt.list_termination_token();
299
300            // @TODO: make a specialized loop for fixed size list_field.dtype()
301            for (i, row) in rows.iter_mut().enumerate() {
302                while row[0] == list_continuation_token {
303                    *row = &row[1..];
304                    let len = dtype_and_data_to_encoded_item_len(
305                        list_field.dtype(),
306                        row,
307                        opt.into_nested(),
308                        dict,
309                    );
310                    nested_rows.push(&row[..len]);
311                    *row = &row[len..];
312                }
313
314                offsets.push(nested_rows.len() as i64);
315
316                // @TODO: Might be better to make this a 2-loop system.
317                if row[0] == list_null_sentinel {
318                    *row = &row[1..];
319                    validity.reserve(num_rows);
320                    validity.extend_constant(i - validity.len(), true);
321                    validity.push(false);
322                    continue;
323                }
324
325                assert_eq!(row[0], list_termination_token);
326                *row = &row[1..];
327            }
328
329            let validity = if validity.is_empty() {
330                None
331            } else {
332                validity.extend_constant(num_rows - validity.len(), true);
333                validity.into_opt_validity()
334            };
335            assert_eq!(offsets.len(), rows.len() + 1);
336
337            let values = decode(
338                &mut nested_rows,
339                opt.into_nested(),
340                dict,
341                list_field.dtype(),
342            );
343
344            ListArray::<i64>::new(
345                dtype.clone(),
346                unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) },
347                values,
348                validity,
349            )
350            .to_boxed()
351        },
352
353        dt => {
354            if matches!(dt, D::Int128) {
355                if let Some(dict) = dict {
356                    return match dict {
357                        RowEncodingContext::Decimal(precision) => {
358                            decimal::decode(rows, opt, *precision).to_boxed()
359                        },
360                        _ => unreachable!(),
361                    };
362                }
363            }
364
365            with_match_arrow_primitive_type!(dt, |$T| {
366                numeric::decode_primitive::<$T>(rows, opt).to_boxed()
367            })
368        },
369    }
370}