use ahash::RandomState;
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::any::{Any, TypeId};
use std::borrow::Borrow;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::sync::Arc;
#[derive(Debug)]
pub struct ArcIntern<T: Eq + Hash + Send + Sync + 'static + ?Sized> {
arc: Arc<T>,
}
type Container<T> = DashMap<Arc<T>, (), RandomState>;
static CONTAINER: OnceCell<DashMap<TypeId, Box<dyn Any + Send + Sync>, RandomState>> =
OnceCell::new();
impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> ArcIntern<T> {
fn from_arc(val: Arc<T>) -> ArcIntern<T> {
let type_map = CONTAINER.get_or_init(|| DashMap::with_hasher(RandomState::new()));
let boxed = if let Some(boxed) = type_map.get(&TypeId::of::<T>()) {
boxed
} else {
type_map
.entry(TypeId::of::<T>())
.or_insert_with(|| Box::new(Container::<T>::with_hasher(RandomState::new())))
.downgrade()
};
let m: &Container<T> = boxed.value().downcast_ref::<Container<T>>().unwrap();
let b = m.entry(val).or_insert(());
return ArcIntern {
arc: b.key().clone(),
};
}
pub fn num_objects_interned() -> usize {
if let Some(m) = CONTAINER
.get()
.and_then(|type_map| type_map.get(&TypeId::of::<T>()))
{
return m.downcast_ref::<Container<T>>().unwrap().len();
}
0
}
pub fn refcount(&self) -> usize {
Arc::strong_count(&self.arc) - 1
}
}
impl<T: Eq + Hash + Send + Sync + 'static> ArcIntern<T> {
pub fn new(val: T) -> ArcIntern<T> {
Self::from_arc(Arc::new(val))
}
}
impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> Clone for ArcIntern<T> {
fn clone(&self) -> Self {
ArcIntern {
arc: self.arc.clone(),
}
}
}
impl<T: Eq + Hash + Send + Sync + ?Sized> Drop for ArcIntern<T> {
fn drop(&mut self) {
if let Some(m) = CONTAINER
.get()
.and_then(|type_map| type_map.get(&TypeId::of::<T>()))
{
let m: &Container<T> = m.downcast_ref::<Container<T>>().unwrap();
m.remove_if(&self.arc, |k, _v| {
Arc::strong_count(k) == 2
});
}
}
}
impl<T: Send + Sync + Hash + Eq + ?Sized> AsRef<T> for ArcIntern<T> {
    /// Borrows the interned value.
    fn as_ref(&self) -> &T {
        // Deref coercion walks from the stored pointer down to `&T`.
        &self.arc
    }
}
impl<T: Eq + Hash + Send + Sync + ?Sized> Borrow<T> for ArcIntern<T> {
    /// Borrows the interned value; identical to the `AsRef<T>` implementation.
    fn borrow(&self) -> &T {
        AsRef::as_ref(self)
    }
}
impl<T: Eq + Hash + Send + Sync + ?Sized> Deref for ArcIntern<T> {
    type Target = T;

    /// Gives direct access to the interned value.
    fn deref(&self) -> &T {
        &self.arc
    }
}
impl<T: Eq + Hash + Send + Sync + Display + ?Sized> Display for ArcIntern<T> {
    /// Formats exactly as the interned value's own `Display` does.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        let inner: &T = self;
        Display::fmt(inner, f)
    }
}
impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> From<Box<T>> for ArcIntern<T> {
    /// Interns a boxed, possibly unsized, value by converting the box into an `Arc<T>`.
    fn from(b: Box<T>) -> Self {
        let shared: Arc<T> = b.into();
        Self::from_arc(shared)
    }
}
impl<'a, T> From<&'a T> for ArcIntern<T>
where
    T: Eq + Hash + Send + Sync + 'static + ?Sized,
    Arc<T>: From<&'a T>,
{
    /// Interns a borrowed value (e.g. a `&[usize]` slice) after converting it
    /// into an `Arc<T>` via that type's `From<&T>` implementation.
    fn from(t: &'a T) -> Self {
        let shared: Arc<T> = t.into();
        Self::from_arc(shared)
    }
}
impl<T: Eq + Hash + Send + Sync + 'static> From<T> for ArcIntern<T> {
fn from(t: T) -> Self {
ArcIntern::new(t)
}
}
impl<T: Eq + Hash + Send + Sync + Default + 'static + ?Sized> Default for ArcIntern<T> {
fn default() -> ArcIntern<T> {
ArcIntern::new(Default::default())
}
}
impl<T: Eq + Hash + Send + Sync + ?Sized> Hash for ArcIntern<T> {
    /// Hashes the interned value itself (not its address), so the hash equals
    /// that of the borrowed `T`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        T::hash(self, state)
    }
}
impl<T: Eq + Hash + Send + Sync + ?Sized> PartialEq for ArcIntern<T> {
    /// Compares by allocation identity. Equal values are interned to one shared
    /// allocation, so pointer equality stands in for value equality.
    fn eq(&self, other: &ArcIntern<T>) -> bool {
        let (lhs, rhs) = (&self.arc, &other.arc);
        Arc::ptr_eq(lhs, rhs)
    }
}

impl<T: Eq + Hash + Send + Sync + ?Sized> Eq for ArcIntern<T> {}
impl<T: Eq + Hash + Send + Sync + PartialOrd + ?Sized> PartialOrd for ArcIntern<T> {
    // Each method forwards to `T`'s method of the same name (rather than relying
    // on the defaults), so a `T` that customises `lt`/`le`/`gt`/`ge` keeps it.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&**self, &**other)
    }

    fn lt(&self, other: &Self) -> bool {
        PartialOrd::lt(&**self, &**other)
    }

    fn le(&self, other: &Self) -> bool {
        PartialOrd::le(&**self, &**other)
    }

    fn gt(&self, other: &Self) -> bool {
        PartialOrd::gt(&**self, &**other)
    }

    fn ge(&self, other: &Self) -> bool {
        PartialOrd::ge(&**self, &**other)
    }
}
impl<T: Eq + Hash + Send + Sync + Ord + ?Sized> Ord for ArcIntern<T> {
    /// Orders handles by their interned values.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&**self, &**other)
    }
}
impl<T: Eq + Hash + Send + Sync + Serialize + ?Sized> Serialize for ArcIntern<T> {
    /// Serializes as the bare interned value, with no wrapper.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        T::serialize(self, serializer)
    }
}
impl<'de, T: Eq + Hash + Send + Sync + 'static + ?Sized + Deserialize<'de>> Deserialize<'de>
    for ArcIntern<T>
{
    /// Deserializes a plain `T` and interns it.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let value = T::deserialize(deserializer)?;
        Ok(Self::new(value))
    }
}
#[cfg(test)]
mod tests {
    use crate::ArcIntern;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn basic() {
        // Temporaries are dropped at the end of each statement, so nothing of
        // type `&str` stays interned afterwards.
        assert_eq!(ArcIntern::new("foo"), ArcIntern::new("foo"));
        assert_ne!(ArcIntern::new("foo"), ArcIntern::new("bar"));
        assert_eq!(ArcIntern::<&str>::num_objects_interned(), 0);

        let _outer = ArcIntern::new("foo".to_string());
        {
            let same = ArcIntern::new("foo".to_string());
            let other = ArcIntern::new("bar".to_string());
            assert_eq!(same.refcount(), 2);
            assert_eq!(other.refcount(), 1);
            assert_eq!(ArcIntern::<String>::num_objects_interned(), 2);
        }
        // Only the outer "foo" handle is still alive.
        assert_eq!(ArcIntern::<String>::num_objects_interned(), 1);
    }

    #[test]
    fn sorting() {
        let mut vals: Vec<ArcIntern<i32>> = [4, 2, 5, 0, 1, 3]
            .iter()
            .map(|&n| ArcIntern::new(n))
            .collect();
        vals.sort();
        let rendered: Vec<String> = vals.iter().map(ToString::to_string).collect();
        assert_eq!(rendered.join(","), "0,1,2,3,4,5");
    }

    #[derive(Eq, PartialEq, Hash)]
    pub struct TestStruct2(String, u64);

    #[test]
    fn sequential() {
        for _ in 0..10_000 {
            let batch: Vec<_> = (0..100)
                .map(|j| ArcIntern::new(TestStruct2("foo".to_string(), j)))
                .collect();
            drop(batch);
        }
        assert_eq!(ArcIntern::<TestStruct2>::num_objects_interned(), 0);
    }

    #[derive(Eq, PartialEq, Hash)]
    pub struct TestStruct(String, u64, Arc<bool>);

    #[test]
    fn multithreading1() {
        // Every interned value holds a clone of `drop_check`; once all are freed
        // only this local reference should remain.
        let drop_check = Arc::new(true);
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let drop_check = Arc::clone(&drop_check);
                thread::spawn(move || {
                    for _ in 0..100_000 {
                        let foo = ArcIntern::new(TestStruct(
                            "foo".to_string(),
                            5,
                            Arc::clone(&drop_check),
                        ));
                        let _bar = ArcIntern::new(TestStruct(
                            "bar".to_string(),
                            10,
                            Arc::clone(&drop_check),
                        ));
                        let mut m = HashMap::new();
                        m.insert(foo, ());
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(Arc::strong_count(&drop_check), 1);
        assert_eq!(ArcIntern::<TestStruct>::num_objects_interned(), 0);
    }

    #[test]
    fn test_unsized() {
        let one_two_three: &[usize] = &[1, 2, 3];
        assert_eq!(
            ArcIntern::<[usize]>::from(one_two_three),
            ArcIntern::from(one_two_three)
        );
        assert_ne!(
            ArcIntern::<[usize]>::from(&[1, 2][..]),
            ArcIntern::from(one_two_three)
        );
        assert_eq!(ArcIntern::<[usize]>::num_objects_interned(), 0);
    }
}