Skip to main content

polars_dtype/categorical/mapping.rs

1use std::fmt;
2use std::hash::BuildHasher;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use arrow::array::builder::StaticArrayBuilder;
6use arrow::array::{Array, MutableUtf8Array, Utf8Array, Utf8ViewArrayBuilder};
7use arrow::datatypes::ArrowDataType;
8use polars_error::{PolarsResult, polars_bail};
9use polars_utils::aliases::{PlFixedStateQuality, PlSeedableRandomStateQuality};
10use polars_utils::parma::raw::RawTable;
11
12use super::CatSize;
13
14pub struct CategoricalMapping {
15    str_to_cat: RawTable<str, CatSize>,
16    cat_to_str_and_hash: boxcar::Vec<(&'static str, u64)>,
17    max_categories: usize,
18    upper_bound: AtomicUsize,
19    hasher: PlSeedableRandomStateQuality,
20}
21
22impl CategoricalMapping {
23    pub(crate) fn new(max_categories: usize) -> Self {
24        Self::with_hasher(max_categories, PlSeedableRandomStateQuality::default())
25    }
26
27    pub fn with_hasher(max_categories: usize, hasher: PlSeedableRandomStateQuality) -> Self {
28        Self {
29            str_to_cat: RawTable::default(),
30            cat_to_str_and_hash: boxcar::Vec::default(),
31            max_categories,
32            upper_bound: AtomicUsize::new(0),
33            hasher,
34        }
35    }
36
37    #[inline(always)]
38    pub fn hasher(&self) -> &PlSeedableRandomStateQuality {
39        &self.hasher
40    }
41
42    pub fn max_categories(&self) -> usize {
43        self.max_categories
44    }
45
46    pub fn set_max_categories(&mut self, max_categories: usize) {
47        assert!(max_categories >= self.num_cats_upper_bound());
48        self.max_categories = max_categories
49    }
50
51    /// Try to convert a string to a categorical id, but don't insert it if it is missing.
52    #[inline(always)]
53    pub fn get_cat(&self, s: &str) -> Option<CatSize> {
54        let hash = self.hasher.hash_one(s);
55        self.get_cat_with_hash(s, hash)
56    }
57
58    /// Same as get_cat, but with the hash pre-computed.
59    #[inline(always)]
60    pub fn get_cat_with_hash(&self, s: &str, hash: u64) -> Option<CatSize> {
61        self.str_to_cat.get(hash, |k| k == s).copied()
62    }
63
64    /// Convert a string to a categorical id.
65    #[inline(always)]
66    pub fn insert_cat(&self, s: &str) -> PolarsResult<CatSize> {
67        let hash = self.hasher.hash_one(s);
68        self.insert_cat_with_hash(s, hash)
69    }
70
71    /// Same as to_cat, but with the hash pre-computed.
72    #[inline(always)]
73    pub fn insert_cat_with_hash(&self, s: &str, hash: u64) -> PolarsResult<CatSize> {
74        self.str_to_cat
75            .try_get_or_insert_with(
76                hash,
77                s,
78                |k| k == s,
79                |k| {
80                    let old_upper_bound = self.upper_bound.fetch_add(1, Ordering::Relaxed);
81                    if old_upper_bound + 1 > self.max_categories {
82                        self.upper_bound.fetch_sub(1, Ordering::Relaxed);
83                        polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed");
84                    }
85                    let hash = PlFixedStateQuality::default().hash_one(k);
86                    let idx = self
87                        .cat_to_str_and_hash
88                        .push((unsafe { core::mem::transmute::<&str, &'static str>(k) }, hash));
89                    Ok(idx as CatSize)
90                },
91            )
92            .copied()
93    }
94
95    /// Try to convert a categorical id to its corresponding string, returning
96    /// None if the string is not in the data structure.
97    #[inline(always)]
98    pub fn cat_to_str(&self, cat: CatSize) -> Option<&str> {
99        self.cat_to_str_and_hash.get(cat as usize).map(|o| o.0)
100    }
101
102    /// Get the string corresponding to a categorical id.
103    ///
104    /// # Safety
105    /// The categorical id must have been returned from `to_cat`, and you must
106    /// have synchronized with the call which inserted it.
107    #[inline(always)]
108    pub unsafe fn cat_to_str_unchecked(&self, cat: CatSize) -> &str {
109        unsafe { self.cat_to_str_and_hash.get_unchecked(cat as usize).0 }
110    }
111
112    /// Try to convert a categorical id to the hash of its corresponding string,
113    /// returning None if the string is not in the data structure.
114    #[inline(always)]
115    pub fn cat_to_hash(&self, cat: CatSize) -> Option<u64> {
116        self.cat_to_str_and_hash.get(cat as usize).map(|o| o.1)
117    }
118
119    /// Get the hash of the string corresponding to a categorical id.
120    ///
121    /// # Safety
122    /// The categorical id must have been returned from `to_cat`, and you must
123    /// have synchronized with the call which inserted it.
124    #[inline(always)]
125    pub unsafe fn cat_to_hash_unchecked(&self, cat: CatSize) -> u64 {
126        unsafe { self.cat_to_str_and_hash.get_unchecked(cat as usize).1 }
127    }
128
129    /// Returns an upper bound such that all strings inserted into the CategoricalMapping
130    /// have a categorical id less than it. Note that due to parallel inserts which
131    /// you have not synchronized with, there may be gaps when using `from_cat`.
132    #[inline(always)]
133    pub fn num_cats_upper_bound(&self) -> usize {
134        // We need to clamp to self.max_categories because a `fetch_add` may
135        // have (temporarily) pushed it beyond the max allowed.
136        self.upper_bound
137            .load(Ordering::Relaxed)
138            .min(self.max_categories)
139    }
140
141    /// Returns the number of categories in this mapping.
142    ///
143    /// This requires exclusive `&mut` access to ensure there are no insertions in-flight.
144    #[inline(always)]
145    pub fn len(&mut self) -> usize {
146        *self.upper_bound.get_mut()
147    }
148
149    #[inline(always)]
150    pub fn is_empty(&mut self) -> bool {
151        self.len() == 0
152    }
153
154    pub fn to_arrow(&self, as_views: bool) -> Box<dyn Array> {
155        let n = self.num_cats_upper_bound();
156        if as_views {
157            let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
158            builder.reserve(n);
159            for i in 0..n {
160                let s = self.cat_to_str(i as CatSize).unwrap_or_default();
161                builder.push_value_ignore_validity(s);
162            }
163            builder.freeze().boxed()
164        } else {
165            let mut builder = MutableUtf8Array::new();
166            builder.reserve(n, 0);
167            for i in 0..n {
168                let s = self.cat_to_str(i as CatSize).unwrap_or_default();
169                builder.push(Some(s));
170            }
171            let arr: Utf8Array<i64> = builder.into();
172            arr.boxed()
173        }
174    }
175}
176
177impl fmt::Debug for CategoricalMapping {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("CategoricalMapping")
180            .field("max_categories", &self.max_categories)
181            .field("upper_bound", &self.upper_bound.load(Ordering::Relaxed))
182            .finish()
183    }
184}