Skip to main content

polars_dtype/categorical/
mod.rs

1use std::fmt;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::str::FromStr;
4use std::sync::{Arc, LazyLock, Mutex, Weak};
5
6use arrow::array::builder::StaticArrayBuilder;
7use arrow::array::{Utf8ViewArray, Utf8ViewArrayBuilder};
8use arrow::datatypes::ArrowDataType;
9use hashbrown::HashTable;
10use hashbrown::hash_table::Entry;
11use polars_error::{PolarsResult, polars_bail, polars_ensure};
12use polars_utils::aliases::*;
13use polars_utils::pl_str::PlSmallStr;
14
15mod catsize;
16mod mapping;
17
18pub use catsize::{CatNative, CatSize};
19pub use mapping::CategoricalMapping;
20
/// Integer width used to store the category ids of a categorical / enum.
///
/// Each variant names the unsigned integer type that backs the ids; see
/// [`CategoricalPhysical::max_categories`] for the number of categories each
/// width admits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum CategoricalPhysical {
    /// Ids are stored as `u8`.
    U8,
    /// Ids are stored as `u16`.
    U16,
    /// Ids are stored as `u32`.
    U32,
}
30
31impl CategoricalPhysical {
32    pub fn max_categories(&self) -> usize {
33        // We might use T::MAX as an indicator, so the maximum number of categories is T::MAX
34        // (giving T::MAX - 1 as the largest category).
35        match self {
36            Self::U8 => u8::MAX as usize,
37            Self::U16 => u16::MAX as usize,
38            Self::U32 => u32::MAX as usize,
39        }
40    }
41
42    pub fn smallest_physical(num_cats: usize) -> PolarsResult<Self> {
43        if num_cats <= u8::MAX as usize {
44            Ok(Self::U8)
45        } else if num_cats <= u16::MAX as usize {
46            Ok(Self::U16)
47        } else if num_cats <= u32::MAX as usize {
48            Ok(Self::U32)
49        } else {
50            polars_bail!(ComputeError: "attempted to insert more categories than the maximum allowed")
51        }
52    }
53
54    pub fn as_str(&self) -> &'static str {
55        match self {
56            Self::U8 => "u8",
57            Self::U16 => "u16",
58            Self::U32 => "u32",
59        }
60    }
61}
62
63impl FromStr for CategoricalPhysical {
64    type Err = ();
65
66    fn from_str(s: &str) -> Result<Self, Self::Err> {
67        match s {
68            "u8" => Ok(Self::U8),
69            "u16" => Ok(Self::U16),
70            "u32" => Ok(Self::U32),
71            _ => Err(()),
72        }
73    }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, Hash)]
77struct CategoricalId {
78    name: PlSmallStr,
79    namespace: PlSmallStr,
80    physical: CategoricalPhysical,
81}
82
83impl CategoricalId {
84    fn global() -> Self {
85        Self {
86            name: PlSmallStr::from_static(""),
87            namespace: PlSmallStr::from_static(""),
88            physical: CategoricalPhysical::U32,
89        }
90    }
91}
92
93// Used to maintain a 1:1 mapping between Categories' ID and the Categories objects themselves.
94// This is important for serialization.
95static CATEGORIES_REGISTRY: LazyLock<Mutex<PlHashMap<CategoricalId, Weak<Categories>>>> =
96    LazyLock::new(|| Mutex::new(PlHashMap::new()));
97
98// Used to make FrozenCategories unique based on their content. This allows comparison of datatypes
99// in constant time by comparing pointers.
100#[expect(clippy::type_complexity)]
101static FROZEN_CATEGORIES_REGISTRY: LazyLock<Mutex<HashTable<(u64, Weak<FrozenCategories>)>>> =
102    LazyLock::new(|| Mutex::new(HashTable::new()));
103
104static FROZEN_CATEGORIES_HASHER: LazyLock<PlSeedableRandomStateQuality> =
105    LazyLock::new(PlSeedableRandomStateQuality::random);
106
107static GLOBAL_CATEGORIES: LazyLock<Arc<Categories>> = LazyLock::new(|| {
108    let mut registry = CATEGORIES_REGISTRY.lock().unwrap();
109    let global_id = CategoricalId::global();
110    if let Some(cats_ref) = registry.get(&global_id) {
111        if let Some(cats) = cats_ref.upgrade() {
112            return cats;
113        }
114    }
115    let global = Arc::new(Categories {
116        id: CategoricalId::global(),
117        mapping: Mutex::new(Weak::new()),
118    });
119    registry.insert(global_id, Arc::downgrade(&global));
120    global
121});
122
123/// A (named) object which is used to indicate which categorical data types have the same mapping.
124pub struct Categories {
125    id: CategoricalId,
126    mapping: Mutex<Weak<CategoricalMapping>>,
127}
128
129impl Categories {
130    /// Creates a new Categories object with the given name, namespace and physical type if none exists, otherwise
131    /// get a reference to an existing object with the same name, namespace and physical type.
132    pub fn new(
133        name: PlSmallStr,
134        namespace: PlSmallStr,
135        physical: CategoricalPhysical,
136    ) -> Arc<Self> {
137        let id = CategoricalId {
138            name,
139            namespace,
140            physical,
141        };
142        let mut registry = CATEGORIES_REGISTRY.lock().unwrap();
143        if let Some(cats_ref) = registry.get(&id) {
144            if let Some(cats) = cats_ref.upgrade() {
145                return cats;
146            }
147        }
148        let mapping = Mutex::new(Weak::new());
149        let slf = Arc::new(Self {
150            id: id.clone(),
151            mapping,
152        });
153        registry.insert(id, Arc::downgrade(&slf));
154        slf
155    }
156
157    /// Returns the global Categories.
158    pub fn global() -> Arc<Self> {
159        GLOBAL_CATEGORIES.clone()
160    }
161
162    /// Returns whether this refers to the global categories.
163    pub fn is_global(self: &Arc<Self>) -> bool {
164        Arc::ptr_eq(self, &*GLOBAL_CATEGORIES)
165    }
166
167    /// Generates a Categories with a random (UUID) name.
168    pub fn random(namespace: PlSmallStr, physical: CategoricalPhysical) -> Arc<Self> {
169        Self::new(uuid::Uuid::new_v4().to_string().into(), namespace, physical)
170    }
171
172    /// The name of this Categories object.
173    pub fn name(&self) -> &PlSmallStr {
174        &self.id.name
175    }
176
177    /// The namespace of this Categories object.
178    pub fn namespace(&self) -> &PlSmallStr {
179        &self.id.namespace
180    }
181
182    /// The physical dtype of the category ids.
183    pub fn physical(&self) -> CategoricalPhysical {
184        self.id.physical
185    }
186
187    /// A stable hash of this Categories object, not the contained categories.
188    pub fn hash(&self) -> u64 {
189        PlFixedStateQuality::default().hash_one(&self.id)
190    }
191
192    /// The mapping for this Categories object. If no mapping currently exists
193    /// it creates a new empty mapping.
194    pub fn mapping(&self) -> Arc<CategoricalMapping> {
195        let mut guard = self.mapping.lock().unwrap();
196        if let Some(arc) = guard.upgrade() {
197            return arc;
198        }
199        let arc = Arc::new(CategoricalMapping::new(self.id.physical.max_categories()));
200        *guard = Arc::downgrade(&arc);
201        arc
202    }
203
204    pub fn freeze(&self) -> Arc<FrozenCategories> {
205        let mapping = self.mapping();
206        let n = mapping.num_cats_upper_bound();
207        FrozenCategories::new((0..n).flat_map(|i| mapping.cat_to_str(i as CatSize))).unwrap()
208    }
209}
210
211impl fmt::Debug for Categories {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        f.debug_struct("Categories")
214            .field("name", &self.id.name)
215            .field("namespace", &self.id.namespace)
216            .field("physical", &self.id.physical)
217            .finish()
218    }
219}
220
221impl Drop for Categories {
222    fn drop(&mut self) {
223        CATEGORIES_REGISTRY.lock().unwrap().remove(&self.id);
224    }
225}
226
227/// An ordered collection of unique strings with an associated pre-computed
228/// mapping to go from string <-> index.
229///
230/// FrozenCategories are globally unique to facilitate constant-time comparison.
231pub struct FrozenCategories {
232    physical: CategoricalPhysical,
233    combined_hash: u64,
234    categories: Utf8ViewArray,
235    mapping: Arc<CategoricalMapping>,
236}
237
238impl FrozenCategories {
239    /// Creates a new FrozenCategories object (or returns a reference to an existing one
240    /// in case these are already known). Returns an error if the categories are not unique.
241    /// It is guaranteed that the nth string ends up with category n (0-indexed).
242    pub fn new<'a, I: IntoIterator<Item = &'a str>>(strings: I) -> PolarsResult<Arc<Self>> {
243        let strings = strings.into_iter();
244        let hasher = FROZEN_CATEGORIES_HASHER.clone();
245        let mut mapping = CategoricalMapping::with_hasher(usize::MAX, hasher);
246        let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
247        builder.reserve(strings.size_hint().0);
248
249        let mut combined_hasher = PlFixedStateQuality::default().build_hasher();
250        for s in strings {
251            combined_hasher.write(s.as_bytes());
252            mapping.insert_cat(s)?;
253            builder.push_value_ignore_validity(s);
254            polars_ensure!(mapping.len() == builder.len(), ComputeError: "FrozenCategories must contain unique strings; found duplicate '{s}'");
255        }
256
257        let combined_hash = combined_hasher.finish();
258        let categories = builder.freeze();
259        mapping.set_max_categories(categories.len()); // Don't allow any further inserts.
260
261        let physical = CategoricalPhysical::smallest_physical(categories.len())?;
262        let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap();
263        let mut last_compared = None; // We have to store the strong reference to avoid a race condition.
264        match registry.entry(
265            combined_hash,
266            |(hash, weak)| {
267                *hash == combined_hash && {
268                    if let Some(frozen_cats) = weak.upgrade() {
269                        let cmp = frozen_cats.categories == categories;
270                        last_compared = Some(frozen_cats);
271                        cmp
272                    } else {
273                        false
274                    }
275                }
276            },
277            |(hash, _weak)| *hash,
278        ) {
279            Entry::Occupied(_) => Ok(last_compared.unwrap()),
280            Entry::Vacant(v) => {
281                let slf = Arc::new(Self {
282                    physical,
283                    combined_hash,
284                    categories,
285                    mapping: Arc::new(mapping),
286                });
287                v.insert((combined_hash, Arc::downgrade(&slf)));
288                Ok(slf)
289            },
290        }
291    }
292
293    /// The categories contained in this FrozenCategories object.
294    pub fn categories(&self) -> &Utf8ViewArray {
295        &self.categories
296    }
297
298    /// The physical dtype of the category ids.
299    pub fn physical(&self) -> CategoricalPhysical {
300        self.physical
301    }
302
303    /// The mapping for this FrozenCategories object.
304    pub fn mapping(&self) -> &Arc<CategoricalMapping> {
305        &self.mapping
306    }
307
308    /// A stable hash of the categories.
309    pub fn hash(&self) -> u64 {
310        self.combined_hash
311    }
312}
313
314impl fmt::Debug for FrozenCategories {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        f.debug_struct("FrozenCategories")
317            .field("physical", &self.physical)
318            .field("categories", &self.categories)
319            .finish()
320    }
321}
322
323impl Drop for FrozenCategories {
324    fn drop(&mut self) {
325        let mut registry = FROZEN_CATEGORIES_REGISTRY.lock().unwrap();
326        while let Ok(entry) =
327            registry.find_entry(self.combined_hash, |(_, weak)| weak.strong_count() == 0)
328        {
329            entry.remove();
330        }
331    }
332}
333
334pub fn ensure_same_categories(left: &Arc<Categories>, right: &Arc<Categories>) -> PolarsResult<()> {
335    if Arc::ptr_eq(left, right) {
336        return Ok(());
337    }
338
339    if left.name() != right.name() {
340        polars_bail!(SchemaMismatch: "Categories name mismatch, left: '{}', right: '{}'.
341
342Operations mixing different Categories are often not supported, you may have to cast.", left.name(), right.name())
343    } else if left.namespace() != right.namespace() {
344        polars_bail!(SchemaMismatch: "Categories have same name ('{}'), but have a mismatch in namespace, left: {}, right: {}.
345
346Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), right.namespace())
347    } else if left.physical() != right.physical() {
348        polars_bail!(SchemaMismatch: "Categories have same name and namespace ('{}', {}), but have a mismatch in dtype, left: {}, right: {}.
349
350Operations mixing different Categories are often not supported, you may have to cast.", left.name(), left.namespace(), left.physical().as_str(), right.physical().as_str())
351    } else {
352        polars_bail!(SchemaMismatch: "Categories which should be equal have different backing objects.
353
354This is a known problem when combining Polars with multiprocessing using fork().")
355    }
356}
357
358pub fn ensure_same_frozen_categories(
359    left: &Arc<FrozenCategories>,
360    right: &Arc<FrozenCategories>,
361) -> PolarsResult<()> {
362    if Arc::ptr_eq(left, right) {
363        return Ok(());
364    }
365
366    polars_bail!(SchemaMismatch: r#"Enum mismatch.
367
368Operations mixing different Enums are often not supported, you may have to cast."#)
369}