polars_dtype/categorical/
mod.rs1use std::fmt;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::str::FromStr;
4use std::sync::{Arc, LazyLock, Mutex, Weak};
5
6use arrow::array::builder::StaticArrayBuilder;
7use arrow::array::{Utf8ViewArray, Utf8ViewArrayBuilder};
8use arrow::datatypes::ArrowDataType;
9use hashbrown::HashTable;
10use hashbrown::hash_table::Entry;
11use polars_error::{PolarsResult, polars_bail, polars_ensure};
12use polars_utils::aliases::*;
13use polars_utils::pl_str::PlSmallStr;
14
15mod catsize;
16mod mapping;
17
18pub use catsize::{CatNative, CatSize};
19pub use mapping::CategoricalMapping;
20
/// The unsigned integer width backing a categorical's physical representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum CategoricalPhysical {
    /// Backed by `u8`.
    U8,
    /// Backed by `u16`.
    U16,
    /// Backed by `u32`.
    U32,
}
30
31impl CategoricalPhysical {
32 pub fn max_categories(&self) -> usize {
33 match self {
36 Self::U8 => u8::MAX as usize,
37 Self::U16 => u16::MAX as usize,
38 Self::U32 => u32::MAX as usize,
39 }
40 }
41
42 pub fn smallest_physical(num_cats: usize) -> PolarsResult<Self> {
43 if num_cats <= u8::MAX as usize {
44 Ok(Self::U8)
45 } else if num_cats <= u16::MAX as usize {
46 Ok(Self::U16)
47 } else if num_cats <= u32::MAX as usize {
48 Ok(Self::U32)
49 } else {
50 polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed")
51 }
52 }
53
54 pub fn as_str(&self) -> &'static str {
55 match self {
56 Self::U8 => "u8",
57 Self::U16 => "u16",
58 Self::U32 => "u32",
59 }
60 }
61}
62
63impl FromStr for CategoricalPhysical {
64 type Err = ();
65
66 fn from_str(s: &str) -> Result<Self, Self::Err> {
67 match s {
68 "u8" => Ok(Self::U8),
69 "u16" => Ok(Self::U16),
70 "u32" => Ok(Self::U32),
71 _ => Err(()),
72 }
73 }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, Hash)]
77struct CategoricalId {
78 name: PlSmallStr,
79 namespace: PlSmallStr,
80 physical: CategoricalPhysical,
81}
82
83impl CategoricalId {
84 fn global() -> Self {
85 Self {
86 name: PlSmallStr::from_static(""),
87 namespace: PlSmallStr::from_static(""),
88 physical: CategoricalPhysical::U32,
89 }
90 }
91}
92
93static CATEGORIES_REGISTRY: LazyLock<Mutex<PlHashMap<CategoricalId, Weak<Categories>>>> =
96 LazyLock::new(|| Mutex::new(PlHashMap::new()));
97
98#[expect(clippy::type_complexity)]
101static FROZEN_CATEGORIES_REGISTRY: LazyLock<Mutex<HashTable<(u64, Weak<FrozenCategories>)>>> =
102 LazyLock::new(|| Mutex::new(HashTable::new()));
103
104static FROZEN_CATEGORIES_HASHER: LazyLock<PlSeedableRandomStateQuality> =
105 LazyLock::new(PlSeedableRandomStateQuality::random);
106
107static GLOBAL_CATEGORIES: LazyLock<Arc<Categories>> = LazyLock::new(|| {
108 let mut registry = CATEGORIES_REGISTRY.lock().unwrap();
109 let global_id = CategoricalId::global();
110 if let Some(cats_ref) = registry.get(&global_id) {
111 if let Some(cats) = cats_ref.upgrade() {
112 return cats;
113 }
114 }
115 let global = Arc::new(Categories {
116 id: CategoricalId::global(),
117 mapping: Mutex::new(Weak::new()),
118 });
119 registry.insert(global_id, Arc::downgrade(&global));
120 global
121});
122
123pub struct Categories {
125 id: CategoricalId,
126 mapping: Mutex<Weak<CategoricalMapping>>,
127}
128
129impl Categories {
130 pub fn new(
133 name: PlSmallStr,
134 namespace: PlSmallStr,
135 physical: CategoricalPhysical,
136 ) -> Arc<Self> {
137 let id = CategoricalId {
138 name,
139 namespace,
140 physical,
141 };
142 let mut registry = CATEGORIES_REGISTRY.lock().unwrap();
143 if let Some(cats_ref) = registry.get(&id) {
144 if let Some(cats) = cats_ref.upgrade() {
145 return cats;
146 }
147 }
148 let mapping = Mutex::new(Weak::new());
149 let slf = Arc::new(Self {
150 id: id.clone(),
151 mapping,
152 });
153 registry.insert(id, Arc::downgrade(&slf));
154 slf
155 }
156
157 pub fn global() -> Arc<Self> {
159 GLOBAL_CATEGORIES.clone()
160 }
161
162 pub fn is_global(self: &Arc<Self>) -> bool {
164 Arc::ptr_eq(self, &*GLOBAL_CATEGORIES)
165 }
166
167 pub fn random(namespace: PlSmallStr, physical: CategoricalPhysical) -> Arc<Self> {
169 Self::new(uuid::Uuid::new_v4().to_string().into(), namespace, physical)
170 }
171
172 pub fn name(&self) -> &PlSmallStr {
174 &self.id.name
175 }
176
177 pub fn namespace(&self) -> &PlSmallStr {
179 &self.id.namespace
180 }
181
182 pub fn physical(&self) -> CategoricalPhysical {
184 self.id.physical
185 }
186
187 pub fn hash(&self) -> u64 {
189 PlFixedStateQuality::default().hash_one(&self.id)
190 }
191
192 pub fn mapping(&self) -> Arc<CategoricalMapping> {
195 let mut guard = self.mapping.lock().unwrap();
196 if let Some(arc) = guard.upgrade() {
197 return arc;
198 }
199 let arc = Arc::new(CategoricalMapping::new(self.id.physical.max_categories()));
200 *guard = Arc::downgrade(&arc);
201 arc
202 }
203
204 pub fn freeze(&self) -> Arc<FrozenCategories> {
205 let mapping = self.mapping();
206 let n = mapping.num_cats_upper_bound();
207 FrozenCategories::new((0..n).flat_map(|i| mapping.cat_to_str(i as CatSize))).unwrap()
208 }
209}
210
211impl fmt::Debug for Categories {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("Categories")
214 .field("name", &self.id.name)
215 .field("namespace", &self.id.namespace)
216 .field("physical", &self.id.physical)
217 .finish()
218 }
219}
220
221impl Drop for Categories {
222 fn drop(&mut self) {
223 CATEGORIES_REGISTRY.lock().unwrap().remove(&self.id);
224 }
225}
226
227pub struct FrozenCategories {
232 physical: CategoricalPhysical,
233 combined_hash: u64,
234 categories: Utf8ViewArray,
235 mapping: Arc<CategoricalMapping>,
236}
237
238impl FrozenCategories {
239 pub fn new<'a, I: IntoIterator<Item = &'a str>>(strings: I) -> PolarsResult<Arc<Self>> {
243 let strings = strings.into_iter();
244 let hasher = FROZEN_CATEGORIES_HASHER.clone();
245 let mut mapping = CategoricalMapping::with_hasher(usize::MAX, hasher);
246 let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
247 builder.reserve(strings.size_hint().0);
248
249 let mut combined_hasher = PlFixedStateQuality::default().build_hasher();
250 for s in strings {
251 combined_hasher.write(s.as_bytes());
252 mapping.insert_cat(s)?;
253 builder.push_value_ignore_validity(s);
254 polars_ensure!(mapping.len() == builder.len(), ComputeError: "FrozenCategories must contain unique strings; found duplicate '{s}'");
255 }
256
257 let combined_hash = combined_hasher.finish();
258 let categories = builder.freeze();
259 mapping.set_max_categories(categories.len()); let physical = CategoricalPhysical::smallest_physical(categories.len())?;
262 let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap();
263 let mut last_compared = None; match registry.entry(
265 combined_hash,
266 |(hash, weak)| {
267 *hash == combined_hash && {
268 if let Some(frozen_cats) = weak.upgrade() {
269 let cmp = frozen_cats.categories == categories;
270 last_compared = Some(frozen_cats);
271 cmp
272 } else {
273 false
274 }
275 }
276 },
277 |(hash, _weak)| *hash,
278 ) {
279 Entry::Occupied(_) => Ok(last_compared.unwrap()),
280 Entry::Vacant(v) => {
281 let slf = Arc::new(Self {
282 physical,
283 combined_hash,
284 categories,
285 mapping: Arc::new(mapping),
286 });
287 v.insert((combined_hash, Arc::downgrade(&slf)));
288 Ok(slf)
289 },
290 }
291 }
292
293 pub fn categories(&self) -> &Utf8ViewArray {
295 &self.categories
296 }
297
298 pub fn physical(&self) -> CategoricalPhysical {
300 self.physical
301 }
302
303 pub fn mapping(&self) -> &Arc<CategoricalMapping> {
305 &self.mapping
306 }
307
308 pub fn hash(&self) -> u64 {
310 self.combined_hash
311 }
312}
313
314impl fmt::Debug for FrozenCategories {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_struct("FrozenCategories")
317 .field("physical", &self.physical)
318 .field("categories", &self.categories)
319 .finish()
320 }
321}
322
323impl Drop for FrozenCategories {
324 fn drop(&mut self) {
325 let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap();
326 while let Ok(entry) =
327 registry.find_entry(self.combined_hash, |(_, weak)| weak.strong_count() == 0)
328 {
329 entry.remove();
330 }
331 }
332}
333
334pub fn ensure_same_categories(left: &Arc<Categories>, right: &Arc<Categories>) -> PolarsResult<()> {
335 if Arc::ptr_eq(left, right) {
336 return Ok(());
337 }
338
339 if left.name() != right.name() {
340 polars_bail!(SchemaMismatch: "Categories name mismatch, left: '{}', right: '{}'.
341
342Operations mixing different Categories are often not supported, you may have to cast.", left.name(), right.name())
343 } else if left.namespace() != right.namespace() {
344 polars_bail!(SchemaMismatch: "Categories have same name ('{}'), but have a mismatch in namespace, left: {}, right: {}.
345
346Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), right.namespace())
347 } else if left.physical() != right.physical() {
348 polars_bail!(SchemaMismatch: "Categories have same name and namespace ('{}', {}), but have a mismatch in dtype, left: {}, right: {}.
349
350Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), left.physical().as_str(), right.physical().as_str())
351 } else {
352 polars_bail!(SchemaMismatch: "Categories which should be equal have different backing objects.
353
354This is a known problem when combining Polars with multiprocessing using fork().")
355 }
356}
357
358pub fn ensure_same_frozen_categories(
359 left: &Arc<FrozenCategories>,
360 right: &Arc<FrozenCategories>,
361) -> PolarsResult<()> {
362 if Arc::ptr_eq(left, right) {
363 return Ok(());
364 }
365
366 polars_bail!(SchemaMismatch: r#"Enum mismatch.
367
368Operations mixing different Enums are often not supported, you may have to cast."#)
369}