#[cfg(not(feature = "fast_dict"))]
use std::collections::hash_map::DefaultHasher;
use std::fmt;
#[cfg(not(feature = "fast_dict"))]
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use ::vec64::Vec64;
use ::vec64::AppendOnlyVec;
use crate::traits::type_unions::Integer;
// Mutex guarding each index shard: `parking_lot` when `fast_dict` is enabled,
// otherwise `std` (whose poisoning is explicitly recovered from at the call sites).
#[cfg(feature = "fast_dict")]
type ShardMutex<T> = parking_lot::Mutex<T>;
#[cfg(not(feature = "fast_dict"))]
type ShardMutex<T> = std::sync::Mutex<T>;

// Per-shard string -> code map. With `fast_dict` it is a hashbrown map so the raw-entry
// API can reuse a precomputed hash; otherwise an ahash- or SipHash-backed map.
#[cfg(feature = "fast_dict")]
type IndexMap<T> = hashbrown::HashMap<String, T, ahash::RandomState>;
#[cfg(all(not(feature = "fast_dict"), feature = "fast_hash"))]
type IndexMap<T> = ahash::AHashMap<String, T>;
#[cfg(all(not(feature = "fast_dict"), not(feature = "fast_hash")))]
type IndexMap<T> = std::collections::HashMap<String, T>;

/// Number of independently locked shards in the dictionary's reverse index.
///
/// Must be a power of two: the `fast_dict` path selects a shard from a fixed number of
/// hash bits, which only addresses every shard when the shard count is a power of two.
const N_INDEX_SHARDS: usize = 64;
const _: () = assert!(
    N_INDEX_SHARDS.is_power_of_two(),
    "N_INDEX_SHARDS must be a power of two"
);
/// Errors produced while interning values into a `Dictionary`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DictionaryError {
    /// Interning one more distinct value would exceed the dictionary's capacity
    /// for its index type.
    Overflow,
}

impl fmt::Display for DictionaryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Overflow => "dictionary cardinality would exceed the capacity of the index type",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DictionaryError {}
/// Reverse index (string -> code) split across `N_INDEX_SHARDS` hash maps, each
/// behind its own mutex, so operations on strings that fall into different shards
/// take different locks.
struct ShardedIndex<T: Integer> {
// Hasher cloned into every shard map; the dictionary also uses it to compute the
// hash from which the shard is chosen (fast_dict path only).
#[cfg(feature = "fast_dict")]
hasher: ahash::RandomState,
// Heap-allocated fixed-size array of shard maps, one mutex per shard.
shards: Box<[ShardMutex<IndexMap<T>>; N_INDEX_SHARDS]>,
}
impl<T: Integer> Default for ShardedIndex<T> {
    /// Builds `N_INDEX_SHARDS` empty shard maps.
    ///
    /// With `fast_dict`, a single randomly seeded `ahash::RandomState` is created and
    /// every shard map is built from a clone of it, so one hash value is valid for both
    /// shard selection and in-shard lookup.
    fn default() -> Self {
        #[cfg(feature = "fast_dict")]
        let hasher = ahash::RandomState::new();

        let shards: Box<[ShardMutex<IndexMap<T>>; N_INDEX_SHARDS]> =
            Box::new(std::array::from_fn(|_| {
                #[cfg(feature = "fast_dict")]
                let map = IndexMap::with_hasher(hasher.clone());
                #[cfg(not(feature = "fast_dict"))]
                let map = IndexMap::default();
                ShardMutex::new(map)
            }));

        Self {
            #[cfg(feature = "fast_dict")]
            hasher,
            shards,
        }
    }
}
impl<T: Integer> ShardedIndex<T> {
    /// Picks the shard for `s` on the non-`fast_dict` path.
    ///
    /// Uses a fixed-key `DefaultHasher`, independent of the shard maps' own hashers, so
    /// the same string always maps to the same shard for the lifetime of the process.
    #[cfg(not(feature = "fast_dict"))]
    #[inline]
    // Kept as a method (rather than a free fn) so call sites read `index.shard_for(..)`.
    #[allow(clippy::unused_self)]
    fn shard_for(&self, s: &str) -> usize {
        let mut state = DefaultHasher::new();
        Hash::hash(s, &mut state);
        let digest = state.finish() as usize;
        digest % N_INDEX_SHARDS
    }
}
/// Shared state behind a `Dictionary`: the interned strings in code order plus a
/// sharded reverse index from string to code.
pub struct DictionaryInner<T: Integer> {
    /// Interned strings; a string's position is its code.
    pub values: AppendOnlyVec<String>,
    /// Reverse lookup from string to code.
    index: ShardedIndex<T>,
}

impl<T: Integer> Default for DictionaryInner<T> {
    /// Creates empty storage whose value capacity is `max_cap::<T>()`.
    fn default() -> Self {
        let values = AppendOnlyVec::with_capacity(max_cap::<T>());
        let index = ShardedIndex::<T>::default();
        Self { values, index }
    }
}
/// Cheaply clonable handle to an interning dictionary mapping strings to dense codes.
///
/// Cloning copies only the `Arc`, so clones share storage: a value added through one
/// clone is visible through every other.
#[derive(Clone)]
pub struct Dictionary<T: Integer> {
    inner: Arc<DictionaryInner<T>>,
}

impl<T: Integer> Default for Dictionary<T> {
    /// Creates an empty dictionary backed by fresh, unshared storage.
    fn default() -> Self {
        let inner = DictionaryInner::default();
        Self {
            inner: Arc::new(inner),
        }
    }
}
impl<T: Integer> Dictionary<T> {
#[inline]
pub fn new() -> Self {
Self::default()
}
pub fn from_values(values: impl Into<Vec64<String>>) -> Self {
let values: Vec64<String> = values.into();
let d = Self::default();
let cap = d.inner.values.capacity();
for s in values {
#[cfg(feature = "fast_dict")]
{
use hashbrown::hash_map::RawEntryMut;
let values_ref = &d.inner.values;
let hash = d.inner.index.hasher.hash_one(&s);
let shard_idx = (hash as usize) & (N_INDEX_SHARDS - 1);
let shard = &d.inner.index.shards[shard_idx];
let mut g = shard.lock();
match g.raw_entry_mut().from_hash(hash, |k: &String| k.as_str() == s.as_str()) {
RawEntryMut::Occupied(_) => continue,
RawEntryMut::Vacant(vac) => {
assert!(
values_ref.count() < cap,
"Dictionary input has more unique values than the capacity of {} ({})",
std::any::type_name::<T>(),
cap
);
let idx = unsafe { values_ref.push(s.clone()).unwrap_unchecked() };
vac.insert_hashed_nocheck(hash, s, T::from_usize(idx));
}
}
}
#[cfg(not(feature = "fast_dict"))]
{
let shard = &d.inner.index.shards[d.inner.index.shard_for(&s)];
let mut g = shard.lock().unwrap_or_else(|p| {
shard.clear_poison();
p.into_inner()
});
if g.get(&s).is_some() {
continue;
}
assert!(
d.inner.values.count() < cap,
"Dictionary input has more unique values than the capacity of {} ({})",
std::any::type_name::<T>(),
cap
);
let idx = unsafe { d.inner.values.push(s.clone()).unwrap_unchecked() };
g.insert(s, T::from_usize(idx));
}
}
d
}
#[inline]
pub fn values(&self) -> &[String] {
self.inner.values.as_slice()
}
#[inline]
pub fn len(&self) -> usize {
self.inner.values.count()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn lookup(&self, s: &str) -> Option<T> {
#[cfg(feature = "fast_dict")]
{
let hash = self.inner.index.hasher.hash_one(s);
let shard_idx = (hash as usize) & (N_INDEX_SHARDS - 1);
let shard = &self.inner.index.shards[shard_idx];
let g = shard.lock();
g.raw_entry()
.from_hash(hash, |k: &String| k.as_str() == s)
.map(|(_, v)| *v)
}
#[cfg(not(feature = "fast_dict"))]
{
let shard = &self.inner.index.shards[self.inner.index.shard_for(s)];
shard
.lock()
.unwrap_or_else(|p| {
shard.clear_poison();
p.into_inner()
})
.get(s)
.copied()
}
}
pub fn add_cat(&self, value: &str) -> Result<T, DictionaryError> {
#[cfg(feature = "fast_dict")]
{
use hashbrown::hash_map::RawEntryMut;
let hash = self.inner.index.hasher.hash_one(value);
let shard_idx = (hash as usize) & (N_INDEX_SHARDS - 1);
let shard = &self.inner.index.shards[shard_idx];
let mut g = shard.lock();
match g.raw_entry_mut().from_hash(hash, |k: &String| k.as_str() == value) {
RawEntryMut::Occupied(e) => Ok(*e.get()),
RawEntryMut::Vacant(vac) => {
let owned = value.to_owned();
let idx = self
.inner
.values
.push(owned.clone())
.ok_or(DictionaryError::Overflow)?;
let code = T::from_usize(idx);
vac.insert_hashed_nocheck(hash, owned, code);
Ok(code)
}
}
}
#[cfg(not(feature = "fast_dict"))]
{
let shard = &self.inner.index.shards[self.inner.index.shard_for(value)];
let mut g = shard.lock().unwrap_or_else(|p| {
shard.clear_poison();
p.into_inner()
});
if let Some(&code) = g.get(value) {
return Ok(code);
}
let idx = self
.inner
.values
.push(value.to_owned())
.ok_or(DictionaryError::Overflow)?;
let code = T::from_usize(idx);
g.insert(value.to_owned(), code);
Ok(code)
}
}
pub fn add_remap_cat(&self, cat: &mut crate::CategoricalArray<T>) {
let incoming = &cat.dictionary.inner.values;
let mut shifted = false;
let mut remap: Vec<T> = Vec::with_capacity(incoming.count());
for (incoming_code, s) in incoming.iter() {
let Ok(new_code) = self.add_cat(s) else { return };
if new_code.to_usize() != incoming_code {
shifted = true;
}
remap.push(new_code);
}
if shifted {
for code in cat.data.iter_mut() {
*code = remap[code.to_usize()];
}
}
cat.dictionary = self.clone();
}
pub fn is_prefix_of(&self, other: &Self) -> bool {
let a = &self.inner.values;
let b = &other.inner.values;
if a.count() > b.count() {
return false;
}
for (i, s) in a.iter() {
match b.get(i) {
Some(t) if t.as_str() == s.as_str() => {}
_ => return false,
}
}
true
}
#[inline]
pub fn shares_with(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner, &other.inner)
}
pub fn detach_to_owned(&mut self) {
let fresh = Dictionary::<T>::default();
for (_, s) in self.inner.values.iter() {
#[cfg(feature = "fast_dict")]
{
use hashbrown::hash_map::RawEntryMut;
let values_ref = &fresh.inner.values;
let hash = fresh.inner.index.hasher.hash_one(s);
let shard_idx = (hash as usize) & (N_INDEX_SHARDS - 1);
let shard = &fresh.inner.index.shards[shard_idx];
let mut g = shard.lock();
let vac = match g.raw_entry_mut().from_hash(hash, |k: &String| k.as_str() == s.as_str()) {
RawEntryMut::Vacant(v) => v,
RawEntryMut::Occupied(_) => continue,
};
let idx = unsafe { values_ref.push(s.clone()).unwrap_unchecked() };
vac.insert_hashed_nocheck(hash, s.clone(), T::from_usize(idx));
}
#[cfg(not(feature = "fast_dict"))]
{
let shard = &fresh.inner.index.shards[fresh.inner.index.shard_for(s)];
let mut g = shard.lock().unwrap_or_else(|p| {
shard.clear_poison();
p.into_inner()
});
let idx = unsafe { fresh.inner.values.push(s.clone()).unwrap_unchecked() };
g.insert(s.clone(), T::from_usize(idx));
}
}
self.inner = fresh.inner;
}
pub fn try_values_iter_mut(
&mut self,
) -> Option<std::slice::IterMut<'_, String>> {
Arc::get_mut(&mut self.inner).map(|inner| inner.values.iter_mut())
}
}
impl<T: Integer> PartialEq for Dictionary<T> {
    /// Two dictionaries are equal when they share storage, or when they hold the same
    /// strings in the same code order.
    fn eq(&self, other: &Self) -> bool {
        if self.shares_with(other) {
            return true;
        }
        let (lhs, rhs) = (&self.inner.values, &other.inner.values);
        lhs.count() == rhs.count()
            && lhs
                .iter()
                .all(|(i, s)| rhs.get(i).is_some_and(|t| t.as_str() == s.as_str()))
    }
}
impl<T: Integer> std::fmt::Debug for Dictionary<T> {
    /// Prints only the number of values, not the values themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.inner.values.count();
        let mut out = f.debug_struct("Dictionary");
        out.field("len", &len);
        out.finish()
    }
}
impl<T: Integer> From<Vec64<String>> for Dictionary<T> {
fn from(values: Vec64<String>) -> Self {
Self::from_values(values)
}
}
impl<T: Integer> From<Vec<String>> for Dictionary<T> {
fn from(values: Vec<String>) -> Self {
Self::from_values(Vec64::from(values))
}
}
impl<T: Integer, S: Into<String>> FromIterator<S> for Dictionary<T> {
fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
let owned: Vec64<String> = Vec64::from(iter.into_iter().map(Into::into).collect::<Vec<_>>());
Self::from_values(owned)
}
}
/// Upper bound on dictionary capacity for index types that can address more codes.
const DEFAULT_WIDE_CAP: usize = 1 << 20;

/// Value capacity reserved for a dictionary indexed by `T`.
///
/// This is the number of codes `T` can represent (`T::max_value() + 1`), clamped to
/// `DEFAULT_WIDE_CAP`. For example, `u8` gets 256, while wider types stop at 2^20.
#[inline]
fn max_cap<T: Integer>() -> usize {
    T::max_value()
        .to_usize()
        .saturating_add(1)
        .min(DEFAULT_WIDE_CAP)
}
/// Shared dictionary used to unify the dictionaries of several categorical chunks,
/// in whichever index width(s) the enabled features provide.
///
/// Variant availability:
/// - `U8`: feature `default_categorical_8`
/// - `U16`, `U64`: feature `extended_categorical`
/// - `U32`: when `default_categorical_8` is off, or `extended_categorical` is on
#[derive(Debug, Clone)]
pub enum CategoryManagerT {
#[cfg(feature = "default_categorical_8")]
U8(Dictionary<u8>),
#[cfg(feature = "extended_categorical")]
U16(Dictionary<u16>),
#[cfg(any(not(feature = "default_categorical_8"), feature = "extended_categorical"))]
U32(Dictionary<u32>),
#[cfg(feature = "extended_categorical")]
U64(Dictionary<u64>),
}
impl CategoryManagerT {
    /// Starts a manager from `array`, sharing its dictionary.
    ///
    /// Returns `None` when `array` is not a categorical text array of an enabled width.
    /// The manager holds a clone of the array's `Dictionary`, which shares the same
    /// storage, so values interned through the manager are also visible through
    /// `array`'s dictionary.
    ///
    /// NOTE(review): `Arc::make_mut` may copy the categorical array when its `Arc` is
    /// shared, although only the dictionary is read here. This is presumably intended
    /// to tie the manager to a uniquely owned chunk; confirm.
    pub fn install_from(array: &mut crate::Array) -> Option<Self> {
        use crate::{Array, TextArray};
        match array {
            #[cfg(any(not(feature = "default_categorical_8"), feature = "extended_categorical"))]
            Array::TextArray(TextArray::Categorical32(arc)) => {
                let cat = Arc::make_mut(arc);
                Some(CategoryManagerT::U32(cat.dictionary.clone()))
            }
            #[cfg(feature = "default_categorical_8")]
            Array::TextArray(TextArray::Categorical8(arc)) => {
                let cat = Arc::make_mut(arc);
                Some(CategoryManagerT::U8(cat.dictionary.clone()))
            }
            #[cfg(feature = "extended_categorical")]
            Array::TextArray(TextArray::Categorical16(arc)) => {
                let cat = Arc::make_mut(arc);
                Some(CategoryManagerT::U16(cat.dictionary.clone()))
            }
            #[cfg(feature = "extended_categorical")]
            Array::TextArray(TextArray::Categorical64(arc)) => {
                let cat = Arc::make_mut(arc);
                Some(CategoryManagerT::U64(cat.dictionary.clone()))
            }
            _ => None,
        }
    }

    /// Merges `array`'s dictionary into the managed one and remaps its codes
    /// (see `Dictionary::add_remap_cat`).
    ///
    /// Does nothing when `array` is not a categorical array of the same index width as
    /// this manager.
    pub fn add_remap_cat(&self, array: &mut crate::Array) {
        use crate::{Array, TextArray};
        match (self, array) {
            #[cfg(any(not(feature = "default_categorical_8"), feature = "extended_categorical"))]
            (CategoryManagerT::U32(d), Array::TextArray(TextArray::Categorical32(arc))) => {
                d.add_remap_cat(Arc::make_mut(arc));
            }
            #[cfg(feature = "default_categorical_8")]
            (CategoryManagerT::U8(d), Array::TextArray(TextArray::Categorical8(arc))) => {
                d.add_remap_cat(Arc::make_mut(arc));
            }
            #[cfg(feature = "extended_categorical")]
            (CategoryManagerT::U16(d), Array::TextArray(TextArray::Categorical16(arc))) => {
                d.add_remap_cat(Arc::make_mut(arc));
            }
            #[cfg(feature = "extended_categorical")]
            (CategoryManagerT::U64(d), Array::TextArray(TextArray::Categorical64(arc))) => {
                d.add_remap_cat(Arc::make_mut(arc));
            }
            _ => {}
        }
    }

    /// Unifies the dictionaries of `chunks` into the manager in `slot`.
    ///
    /// While `slot` is `None`, each chunk is tried as the seed via `install_from`.
    /// Once a manager exists, each subsequent chunk is merged into it.
    pub fn add_remap_cats<'a, I>(slot: &mut Option<Self>, chunks: I)
    where
        I: IntoIterator<Item = &'a mut crate::Array>,
    {
        for chunk in chunks {
            match slot {
                Some(m) => m.add_remap_cat(chunk),
                None => *slot = Self::install_from(chunk),
            }
        }
    }

    /// Returns the number of distinct values in the managed dictionary.
    pub fn len(&self) -> usize {
        match self {
            #[cfg(feature = "default_categorical_8")]
            CategoryManagerT::U8(d) => d.len(),
            #[cfg(feature = "extended_categorical")]
            CategoryManagerT::U16(d) => d.len(),
            #[cfg(any(not(feature = "default_categorical_8"), feature = "extended_categorical"))]
            CategoryManagerT::U32(d) => d.len(),
            #[cfg(feature = "extended_categorical")]
            CategoryManagerT::U64(d) => d.len(),
        }
    }

    /// Returns `true` when the managed dictionary holds no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_dictionary_starts_empty() {
        let d: Dictionary<u32> = Dictionary::new();
        assert_eq!(d.len(), 0);
        assert!(d.is_empty());
        assert_eq!(d.lookup("anything"), None);
    }

    #[test]
    fn intern_assigns_dense_sequential_codes() {
        let d: Dictionary<u32> = Dictionary::new();
        assert_eq!(d.add_cat("a"), Ok(0));
        assert_eq!(d.add_cat("b"), Ok(1));
        assert_eq!(d.add_cat("c"), Ok(2));
        assert_eq!(d.add_cat("a"), Ok(0));
        assert_eq!(d.len(), 3);
        let values: Vec<&str> = d.values().iter().map(|s| s.as_str()).collect();
        assert_eq!(values, vec!["a", "b", "c"]);
    }

    #[test]
    fn clones_share_state() {
        let d: Dictionary<u32> = Dictionary::new();
        let cloned = d.clone();
        assert!(d.shares_with(&cloned));
        assert_eq!(d.add_cat("a"), Ok(0));
        let values: Vec<&str> = cloned.values().iter().map(|s| s.as_str()).collect();
        assert_eq!(values, vec!["a"]);
    }

    #[test]
    fn detach_breaks_sharing() {
        let a: Dictionary<u32> = Dictionary::new();
        let _ = a.add_cat("x").unwrap();
        let mut b = a.clone();
        b.detach_to_owned();
        assert_eq!(a.values().get(0).map(|s| s.as_str()), Some("x"));
        assert_eq!(b.values().get(0).map(|s| s.as_str()), Some("x"));
        assert!(!a.shares_with(&b));
        let _ = b.add_cat("y").unwrap();
        assert_eq!(a.len(), 1);
        assert_eq!(b.len(), 2);
    }

    #[test]
    fn detach_keeps_codes_and_rebuilds_index() {
        let a: Dictionary<u32> = Dictionary::from_iter(["x", "y", "z"]);
        let mut b = a.clone();
        b.detach_to_owned();
        assert_eq!(b, a);
        assert_eq!(b.lookup("x"), Some(0));
        assert_eq!(b.lookup("z"), Some(2));
        assert_eq!(b.add_cat("y"), Ok(1));
        assert_eq!(b.len(), 3);
    }

    #[test]
    fn from_values_dedups_in_first_occurrence_order() {
        let d: Dictionary<u32> = Dictionary::from_iter(["b", "a", "b", "c", "a"]);
        assert_eq!(d.len(), 3);
        let values: Vec<&str> = d.values().iter().map(|s| s.as_str()).collect();
        assert_eq!(values, vec!["b", "a", "c"]);
        assert_eq!(d.lookup("b"), Some(0));
        assert_eq!(d.lookup("a"), Some(1));
        assert_eq!(d.lookup("c"), Some(2));
        assert_eq!(d.lookup("z"), None);
    }

    #[test]
    #[should_panic]
    fn from_values_panics_past_u8_cap() {
        let _d: Dictionary<u8> = (0..257).map(|i| format!("v{i}")).collect();
    }

    #[test]
    fn equality_is_by_content_and_order() {
        let a: Dictionary<u32> = Dictionary::from_iter(["x", "y"]);
        let b: Dictionary<u32> = Dictionary::from_iter(["x", "y"]);
        let c: Dictionary<u32> = Dictionary::from_iter(["y", "x"]);
        assert!(!a.shares_with(&b));
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert!(a.is_prefix_of(&a.clone()));
    }

    #[test]
    fn is_prefix_of_recognises_prefix() {
        let a: Dictionary<u32> = Dictionary::from_iter(["x", "y"]);
        let b: Dictionary<u32> = Dictionary::from_iter(["x", "y", "z"]);
        assert!(a.is_prefix_of(&b));
        assert!(!b.is_prefix_of(&a));
        let c: Dictionary<u32> = Dictionary::from_iter(["x", "z"]);
        assert!(!a.is_prefix_of(&c));
    }

    #[test]
    fn intern_returns_overflow_at_u8_cap() {
        let d: Dictionary<u8> = Dictionary::new();
        for i in 0..256u32 {
            d.add_cat(&format!("v{i}")).unwrap();
        }
        assert_eq!(d.add_cat("overflow"), Err(DictionaryError::Overflow));
        assert_eq!(d.len(), 256);
    }

    #[test]
    fn concurrent_intern_under_u8_cap_no_leaks() {
        use std::sync::Arc;
        use std::thread;
        let d: Arc<Dictionary<u8>> = Arc::new(Dictionary::new());
        let mut handles = Vec::new();
        for t in 0..16 {
            let d = Arc::clone(&d);
            handles.push(thread::spawn(move || {
                let mut successes = 0u32;
                let mut overflows = 0u32;
                for i in 0..100 {
                    let s = format!("t{t}_v{i}");
                    match d.add_cat(&s) {
                        Ok(_) => successes += 1,
                        Err(DictionaryError::Overflow) => overflows += 1,
                    }
                }
                (successes, overflows)
            }));
        }
        let (mut total_succ, mut total_ovf) = (0u32, 0u32);
        for h in handles {
            let (s, o) = h.join().unwrap();
            total_succ += s;
            total_ovf += o;
        }
        assert_eq!(d.len(), 256);
        assert_eq!(total_succ, 256);
        assert_eq!(total_ovf, 16 * 100 - 256);
    }

    #[test]
    fn concurrent_intern_distinct_strings_no_duplicates() {
        use std::sync::Arc;
        use std::thread;
        let d: Arc<Dictionary<u32>> = Arc::new(Dictionary::new());
        let mut handles = Vec::new();
        for t in 0..8 {
            let d = Arc::clone(&d);
            handles.push(thread::spawn(move || {
                for i in 0..500 {
                    let _ = d.add_cat(&format!("t{t}_v{i}")).unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(d.len(), 8 * 500);
    }
}