polars_arrow/array/dictionary/
mod.rs

1use std::hash::Hash;
2use std::hint::unreachable_unchecked;
3
4use crate::bitmap::Bitmap;
5use crate::bitmap::utils::{BitmapIter, ZipValidity};
6use crate::datatypes::{ArrowDataType, IntegerType};
7use crate::scalar::{Scalar, new_scalar};
8use crate::trusted_len::TrustedLen;
9use crate::types::NativeType;
10
11mod ffi;
12pub(super) mod fmt;
13mod iterator;
14mod mutable;
15use crate::array::specification::check_indexes_unchecked;
16mod typed_iterator;
17mod value_map;
18
19pub use iterator::*;
20pub use mutable::*;
21use polars_error::{PolarsResult, polars_bail};
22
23use super::primitive::PrimitiveArray;
24use super::specification::check_indexes;
25use super::{Array, Splitable, new_empty_array, new_null_array};
26use crate::array::dictionary::typed_iterator::{
27    DictValue, DictionaryIterTyped, DictionaryValuesIterTyped,
28};
29
30/// Trait denoting [`NativeType`]s that can be used as keys of a dictionary.
31/// # Safety
32///
33/// Any implementation of this trait must ensure that `always_fits_usize` only
34/// returns `true` if all values succeeds on `value::try_into::<usize>().unwrap()`.
35pub unsafe trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> + Hash {
36    /// The corresponding [`IntegerType`] of this key
37    const KEY_TYPE: IntegerType;
38    const MAX_USIZE_VALUE: usize;
39
40    /// Represents this key as a `usize`.
41    ///
42    /// # Safety
43    /// The caller _must_ have checked that the value can be cast to `usize`.
44    #[inline]
45    unsafe fn as_usize(self) -> usize {
46        match self.try_into() {
47            Ok(v) => v,
48            Err(_) => unreachable_unchecked(),
49        }
50    }
51
52    /// Create a key from a `usize` without checking bounds.
53    ///
54    /// # Safety
55    /// The caller _must_ have checked that the value can be created from a `usize`.
56    #[inline]
57    unsafe fn from_usize_unchecked(x: usize) -> Self {
58        debug_assert!(Self::try_from(x).is_ok());
59        unsafe { Self::try_from(x).unwrap_unchecked() }
60    }
61
62    /// If the key type always can be converted to `usize`.
63    fn always_fits_usize() -> bool {
64        false
65    }
66}
67
68unsafe impl DictionaryKey for i8 {
69    const KEY_TYPE: IntegerType = IntegerType::Int8;
70    const MAX_USIZE_VALUE: usize = i8::MAX as usize;
71}
72unsafe impl DictionaryKey for i16 {
73    const KEY_TYPE: IntegerType = IntegerType::Int16;
74    const MAX_USIZE_VALUE: usize = i16::MAX as usize;
75}
76unsafe impl DictionaryKey for i32 {
77    const KEY_TYPE: IntegerType = IntegerType::Int32;
78    const MAX_USIZE_VALUE: usize = i32::MAX as usize;
79}
80unsafe impl DictionaryKey for i64 {
81    const KEY_TYPE: IntegerType = IntegerType::Int64;
82    const MAX_USIZE_VALUE: usize = i64::MAX as usize;
83}
84unsafe impl DictionaryKey for i128 {
85    const KEY_TYPE: IntegerType = IntegerType::Int128;
86    const MAX_USIZE_VALUE: usize = i128::MAX as usize;
87}
88unsafe impl DictionaryKey for u8 {
89    const KEY_TYPE: IntegerType = IntegerType::UInt8;
90    const MAX_USIZE_VALUE: usize = u8::MAX as usize;
91
92    fn always_fits_usize() -> bool {
93        true
94    }
95}
96unsafe impl DictionaryKey for u16 {
97    const KEY_TYPE: IntegerType = IntegerType::UInt16;
98    const MAX_USIZE_VALUE: usize = u16::MAX as usize;
99
100    fn always_fits_usize() -> bool {
101        true
102    }
103}
104unsafe impl DictionaryKey for u32 {
105    const KEY_TYPE: IntegerType = IntegerType::UInt32;
106    const MAX_USIZE_VALUE: usize = u32::MAX as usize;
107
108    fn always_fits_usize() -> bool {
109        true
110    }
111}
112unsafe impl DictionaryKey for u64 {
113    const KEY_TYPE: IntegerType = IntegerType::UInt64;
114    const MAX_USIZE_VALUE: usize = u64::MAX as usize;
115
116    #[cfg(target_pointer_width = "64")]
117    fn always_fits_usize() -> bool {
118        true
119    }
120}
121unsafe impl DictionaryKey for u128 {
122    const KEY_TYPE: IntegerType = IntegerType::UInt128;
123    const MAX_USIZE_VALUE: usize = u128::MAX as usize;
124}
125
126/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of
127/// values is low compared to the length of the [`Array`].
128///
129/// # Safety
130/// This struct guarantees that each item of [`DictionaryArray::keys`] is castable to `usize` and
131/// its value is smaller than [`DictionaryArray::values`]`.len()`. In other words, you can safely
132/// use `unchecked` calls to retrieve the values
133#[derive(Clone)]
134pub struct DictionaryArray<K: DictionaryKey> {
135    dtype: ArrowDataType,
136    keys: PrimitiveArray<K>,
137    values: Box<dyn Array>,
138}
139
140fn check_dtype(
141    key_type: IntegerType,
142    dtype: &ArrowDataType,
143    values_dtype: &ArrowDataType,
144) -> PolarsResult<()> {
145    if let ArrowDataType::Dictionary(key, value, _) = dtype.to_logical_type() {
146        if *key != key_type {
147            polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys")
148        }
149        if value.as_ref().to_logical_type() != values_dtype.to_logical_type() {
150            polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values")
151        }
152    } else {
153        polars_bail!(ComputeError: "DictionaryArray must be initialized with logical DataType::Dictionary")
154    }
155    Ok(())
156}
157
158impl<K: DictionaryKey> DictionaryArray<K> {
159    /// Returns a new [`DictionaryArray`].
160    /// # Implementation
161    /// This function is `O(N)` where `N` is the length of keys
162    /// # Errors
163    /// This function errors iff
164    /// * the `dtype`'s logical type is not a `DictionaryArray`
165    /// * the `dtype`'s keys is not compatible with `keys`
166    /// * the `dtype`'s values's dtype is not equal with `values.dtype()`
167    /// * any of the keys's values is not represented in `usize` or is `>= values.len()`
168    pub fn try_new(
169        dtype: ArrowDataType,
170        keys: PrimitiveArray<K>,
171        values: Box<dyn Array>,
172    ) -> PolarsResult<Self> {
173        check_dtype(K::KEY_TYPE, &dtype, values.dtype())?;
174
175        if keys.null_count() != keys.len() {
176            if K::always_fits_usize() {
177                // SAFETY: we just checked that conversion to `usize` always
178                // succeeds
179                unsafe { check_indexes_unchecked(keys.values(), values.len()) }?;
180            } else {
181                check_indexes(keys.values(), values.len())?;
182            }
183        }
184
185        Ok(Self {
186            dtype,
187            keys,
188            values,
189        })
190    }
191
192    /// Returns a new [`DictionaryArray`].
193    /// # Implementation
194    /// This function is `O(N)` where `N` is the length of keys
195    /// # Errors
196    /// This function errors iff
197    /// * any of the keys's values is not represented in `usize` or is `>= values.len()`
198    pub fn try_from_keys(keys: PrimitiveArray<K>, values: Box<dyn Array>) -> PolarsResult<Self> {
199        let dtype = Self::default_dtype(values.dtype().clone());
200        Self::try_new(dtype, keys, values)
201    }
202
203    /// Returns a new [`DictionaryArray`].
204    /// # Errors
205    /// This function errors iff
206    /// * the `dtype`'s logical type is not a `DictionaryArray`
207    /// * the `dtype`'s keys is not compatible with `keys`
208    /// * the `dtype`'s values's dtype is not equal with `values.dtype()`
209    ///
210    /// # Safety
211    /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()`
212    pub unsafe fn try_new_unchecked(
213        dtype: ArrowDataType,
214        keys: PrimitiveArray<K>,
215        values: Box<dyn Array>,
216    ) -> PolarsResult<Self> {
217        check_dtype(K::KEY_TYPE, &dtype, values.dtype())?;
218
219        Ok(Self {
220            dtype,
221            keys,
222            values,
223        })
224    }
225
226    /// Returns a new empty [`DictionaryArray`].
227    pub fn new_empty(dtype: ArrowDataType) -> Self {
228        let values = Self::try_get_child(&dtype).unwrap();
229        let values = new_empty_array(values.clone());
230        Self::try_new(
231            dtype,
232            PrimitiveArray::<K>::new_empty(K::PRIMITIVE.into()),
233            values,
234        )
235        .unwrap()
236    }
237
238    /// Returns an [`DictionaryArray`] whose all elements are null
239    #[inline]
240    pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
241        let values = Self::try_get_child(&dtype).unwrap();
242        let values = new_null_array(values.clone(), 1);
243        Self::try_new(
244            dtype,
245            PrimitiveArray::<K>::new_null(K::PRIMITIVE.into(), length),
246            values,
247        )
248        .unwrap()
249    }
250
251    /// Returns an iterator of [`Option<Box<dyn Scalar>>`].
252    /// # Implementation
253    /// This function will allocate a new [`Scalar`] per item and is usually not performant.
254    /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that.
255    pub fn iter(
256        &self,
257    ) -> ZipValidity<Box<dyn Scalar>, DictionaryValuesIter<'_, K>, BitmapIter<'_>> {
258        ZipValidity::new_with_validity(DictionaryValuesIter::new(self), self.keys.validity())
259    }
260
261    /// Returns an iterator of [`Box<dyn Scalar>`]
262    /// # Implementation
263    /// This function will allocate a new [`Scalar`] per item and is usually not performant.
264    /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that.
265    pub fn values_iter(&self) -> DictionaryValuesIter<'_, K> {
266        DictionaryValuesIter::new(self)
267    }
268
269    /// Returns an iterator over the values [`V::IterValue`].
270    ///
271    /// # Panics
272    ///
273    /// Panics if the keys of this [`DictionaryArray`] has any nulls.
274    /// If they do [`DictionaryArray::iter_typed`] should be used.
275    pub fn values_iter_typed<V: DictValue>(
276        &self,
277    ) -> PolarsResult<DictionaryValuesIterTyped<'_, K, V>> {
278        let keys = &self.keys;
279        assert_eq!(keys.null_count(), 0);
280        let values = self.values.as_ref();
281        let values = V::downcast_values(values)?;
282        Ok(DictionaryValuesIterTyped::new(keys, values))
283    }
284
285    /// Returns an iterator over the optional values of  [`Option<V::IterValue>`].
286    pub fn iter_typed<V: DictValue>(&self) -> PolarsResult<DictionaryIterTyped<'_, K, V>> {
287        let keys = &self.keys;
288        let values = self.values.as_ref();
289        let values = V::downcast_values(values)?;
290        Ok(DictionaryIterTyped::new(keys, values))
291    }
292
293    /// Returns the [`ArrowDataType`] of this [`DictionaryArray`]
294    #[inline]
295    pub fn dtype(&self) -> &ArrowDataType {
296        &self.dtype
297    }
298
299    /// Returns whether the values of this [`DictionaryArray`] are ordered
300    #[inline]
301    pub fn is_ordered(&self) -> bool {
302        match self.dtype.to_logical_type() {
303            ArrowDataType::Dictionary(_, _, is_ordered) => *is_ordered,
304            _ => unreachable!(),
305        }
306    }
307
308    pub(crate) fn default_dtype(values_datatype: ArrowDataType) -> ArrowDataType {
309        ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false)
310    }
311
312    /// Slices this [`DictionaryArray`].
313    /// # Panics
314    /// iff `offset + length > self.len()`.
315    pub fn slice(&mut self, offset: usize, length: usize) {
316        self.keys.slice(offset, length);
317    }
318
319    /// Slices this [`DictionaryArray`].
320    ///
321    /// # Safety
322    /// Safe iff `offset + length <= self.len()`.
323    pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
324        self.keys.slice_unchecked(offset, length);
325    }
326
327    impl_sliced!();
328
329    /// Returns this [`DictionaryArray`] with a new validity.
330    /// # Panic
331    /// This function panics iff `validity.len() != self.len()`.
332    #[must_use]
333    pub fn with_validity(mut self, validity: Option<Bitmap>) -> Self {
334        self.set_validity(validity);
335        self
336    }
337
338    /// Sets the validity of the keys of this [`DictionaryArray`].
339    /// # Panics
340    /// This function panics iff `validity.len() != self.len()`.
341    pub fn set_validity(&mut self, validity: Option<Bitmap>) {
342        self.keys.set_validity(validity);
343    }
344
345    impl_into_array!();
346
347    /// Returns the length of this array
348    #[inline]
349    pub fn len(&self) -> usize {
350        self.keys.len()
351    }
352
353    /// The optional validity. Equivalent to `self.keys().validity()`.
354    #[inline]
355    pub fn validity(&self) -> Option<&Bitmap> {
356        self.keys.validity()
357    }
358
359    /// Returns the keys of the [`DictionaryArray`]. These keys can be used to fetch values
360    /// from `values`.
361    #[inline]
362    pub fn keys(&self) -> &PrimitiveArray<K> {
363        &self.keys
364    }
365
366    /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize`
367    #[inline]
368    pub fn keys_values_iter(&self) -> impl TrustedLen<Item = usize> + Clone + '_ {
369        // SAFETY: invariant of the struct
370        self.keys.values_iter().map(|x| unsafe { x.as_usize() })
371    }
372
373    /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize`
374    #[inline]
375    pub fn keys_iter(&self) -> impl TrustedLen<Item = Option<usize>> + Clone + '_ {
376        // SAFETY: invariant of the struct
377        self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() }))
378    }
379
380    /// Returns the keys' value of the [`DictionaryArray`] as `usize`
381    /// # Panics
382    /// This function panics iff `index >= self.len()`
383    #[inline]
384    pub fn key_value(&self, index: usize) -> usize {
385        // SAFETY: invariant of the struct
386        unsafe { self.keys.values()[index].as_usize() }
387    }
388
389    /// Returns the values of the [`DictionaryArray`].
390    #[inline]
391    pub fn values(&self) -> &Box<dyn Array> {
392        &self.values
393    }
394
395    /// Returns the value of the [`DictionaryArray`] at position `i`.
396    /// # Implementation
397    /// This function will allocate a new [`Scalar`] and is usually not performant.
398    /// Consider calling `keys` and `values`, downcasting `values`, and iterating over that.
399    /// # Panic
400    /// This function panics iff `index >= self.len()`
401    #[inline]
402    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
403        // SAFETY: invariant of this struct
404        let index = unsafe { self.keys.value(index).as_usize() };
405        new_scalar(self.values.as_ref(), index)
406    }
407
408    pub(crate) fn try_get_child(dtype: &ArrowDataType) -> PolarsResult<&ArrowDataType> {
409        Ok(match dtype.to_logical_type() {
410            ArrowDataType::Dictionary(_, values, _) => values.as_ref(),
411            _ => {
412                polars_bail!(ComputeError: "Dictionaries must be initialized with DataType::Dictionary")
413            },
414        })
415    }
416
417    pub fn take(self) -> (ArrowDataType, PrimitiveArray<K>, Box<dyn Array>) {
418        (self.dtype, self.keys, self.values)
419    }
420}
421
422impl<K: DictionaryKey> Array for DictionaryArray<K> {
423    impl_common_array!();
424
425    fn validity(&self) -> Option<&Bitmap> {
426        self.keys.validity()
427    }
428
429    #[inline]
430    fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
431        Box::new(self.clone().with_validity(validity))
432    }
433}
434
435impl<K: DictionaryKey> Splitable for DictionaryArray<K> {
436    fn check_bound(&self, offset: usize) -> bool {
437        offset < self.len()
438    }
439
440    unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
441        let (lhs_keys, rhs_keys) = unsafe { Splitable::split_at_unchecked(&self.keys, offset) };
442
443        (
444            Self {
445                dtype: self.dtype.clone(),
446                keys: lhs_keys,
447                values: self.values.clone(),
448            },
449            Self {
450                dtype: self.dtype.clone(),
451                keys: rhs_keys,
452                values: self.values.clone(),
453            },
454        )
455    }
456}