polars_dtype/categorical/
mapping.rs1use std::fmt;
2use std::hash::BuildHasher;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use arrow::array::builder::StaticArrayBuilder;
6use arrow::array::{Array, MutableUtf8Array, Utf8Array, Utf8ViewArrayBuilder};
7use arrow::datatypes::ArrowDataType;
8use polars_error::{PolarsResult, polars_bail};
9use polars_utils::aliases::{PlFixedStateQuality, PlSeedableRandomStateQuality};
10use polars_utils::parma::raw::RawTable;
11
12use super::CatSize;
13
14pub struct CategoricalMapping {
15 str_to_cat: RawTable<str, CatSize>,
16 cat_to_str_and_hash: boxcar::Vec<(&'static str, u64)>,
17 max_categories: usize,
18 upper_bound: AtomicUsize,
19 hasher: PlSeedableRandomStateQuality,
20}
21
22impl CategoricalMapping {
23 pub(crate) fn new(max_categories: usize) -> Self {
24 Self::with_hasher(max_categories, PlSeedableRandomStateQuality::default())
25 }
26
27 pub fn with_hasher(max_categories: usize, hasher: PlSeedableRandomStateQuality) -> Self {
28 Self {
29 str_to_cat: RawTable::default(),
30 cat_to_str_and_hash: boxcar::Vec::default(),
31 max_categories,
32 upper_bound: AtomicUsize::new(0),
33 hasher,
34 }
35 }
36
37 #[inline(always)]
38 pub fn hasher(&self) -> &PlSeedableRandomStateQuality {
39 &self.hasher
40 }
41
42 pub fn max_categories(&self) -> usize {
43 self.max_categories
44 }
45
46 pub fn set_max_categories(&mut self, max_categories: usize) {
47 assert!(max_categories >= self.num_cats_upper_bound());
48 self.max_categories = max_categories
49 }
50
51 #[inline(always)]
53 pub fn get_cat(&self, s: &str) -> Option<CatSize> {
54 let hash = self.hasher.hash_one(s);
55 self.get_cat_with_hash(s, hash)
56 }
57
58 #[inline(always)]
60 pub fn get_cat_with_hash(&self, s: &str, hash: u64) -> Option<CatSize> {
61 self.str_to_cat.get(hash, |k| k == s).copied()
62 }
63
64 #[inline(always)]
66 pub fn insert_cat(&self, s: &str) -> PolarsResult<CatSize> {
67 let hash = self.hasher.hash_one(s);
68 self.insert_cat_with_hash(s, hash)
69 }
70
71 #[inline(always)]
73 pub fn insert_cat_with_hash(&self, s: &str, hash: u64) -> PolarsResult<CatSize> {
74 self.str_to_cat
75 .try_get_or_insert_with(
76 hash,
77 s,
78 |k| k == s,
79 |k| {
80 let old_upper_bound = self.upper_bound.fetch_add(1, Ordering::Relaxed);
81 if old_upper_bound + 1 > self.max_categories {
82 self.upper_bound.fetch_sub(1, Ordering::Relaxed);
83 polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed");
84 }
85 let hash = PlFixedStateQuality::default().hash_one(k);
86 let idx = self
87 .cat_to_str_and_hash
88 .push((unsafe { core::mem::transmute::<&str, &'static str>(k) }, hash));
89 Ok(idx as CatSize)
90 },
91 )
92 .copied()
93 }
94
95 #[inline(always)]
98 pub fn cat_to_str(&self, cat: CatSize) -> Option<&str> {
99 self.cat_to_str_and_hash.get(cat as usize).map(|o| o.0)
100 }
101
102 #[inline(always)]
108 pub unsafe fn cat_to_str_unchecked(&self, cat: CatSize) -> &str {
109 unsafe { self.cat_to_str_and_hash.get_unchecked(cat as usize).0 }
110 }
111
112 #[inline(always)]
115 pub fn cat_to_hash(&self, cat: CatSize) -> Option<u64> {
116 self.cat_to_str_and_hash.get(cat as usize).map(|o| o.1)
117 }
118
119 #[inline(always)]
125 pub unsafe fn cat_to_hash_unchecked(&self, cat: CatSize) -> u64 {
126 unsafe { self.cat_to_str_and_hash.get_unchecked(cat as usize).1 }
127 }
128
129 #[inline(always)]
133 pub fn num_cats_upper_bound(&self) -> usize {
134 self.upper_bound
137 .load(Ordering::Relaxed)
138 .min(self.max_categories)
139 }
140
141 #[inline(always)]
145 pub fn len(&mut self) -> usize {
146 *self.upper_bound.get_mut()
147 }
148
149 #[inline(always)]
150 pub fn is_empty(&mut self) -> bool {
151 self.len() == 0
152 }
153
154 pub fn to_arrow(&self, as_views: bool) -> Box<dyn Array> {
155 let n = self.num_cats_upper_bound();
156 if as_views {
157 let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
158 builder.reserve(n);
159 for i in 0..n {
160 let s = self.cat_to_str(i as CatSize).unwrap_or_default();
161 builder.push_value_ignore_validity(s);
162 }
163 builder.freeze().boxed()
164 } else {
165 let mut builder = MutableUtf8Array::new();
166 builder.reserve(n, 0);
167 for i in 0..n {
168 let s = self.cat_to_str(i as CatSize).unwrap_or_default();
169 builder.push(Some(s));
170 }
171 let arr: Utf8Array<i64> = builder.into();
172 arr.boxed()
173 }
174 }
175}
176
177impl fmt::Debug for CategoricalMapping {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("CategoricalMapping")
180 .field("max_categories", &self.max_categories)
181 .field("upper_bound", &self.upper_bound.load(Ordering::Relaxed))
182 .finish()
183 }
184}