use crate::{
loom::*,
ref_count::{Interned, Interner},
};
use std::{
borrow::Borrow,
cmp::Ordering,
collections::BTreeSet,
fmt::{Debug, Display, Formatter, Pointer},
hash::Hasher,
ops::Deref,
sync::Arc,
};
/// An interner that deduplicates values using their `std::cmp::Ord` ordering.
///
/// Values live in a `BTreeSet` guarded by an `RwLock`. Cloning an `OrdInterner` is cheap:
/// every clone shares the same underlying set through an `Arc`.
pub struct OrdInterner<T: ?Sized + std::cmp::Ord> {
    // Shared interner state; clones of this handle point at the same `Ord<T>`.
    inner: Arc<Ord<T>>,
}
impl<T: ?Sized + std::cmp::Ord> Clone for OrdInterner<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// Shared state behind an [`OrdInterner`]: the set of currently interned values.
///
/// Note that this struct is named `Ord`, which shadows the `std::cmp::Ord` trait name in
/// this module; that is why trait bounds throughout the file spell out `std::cmp::Ord`.
// NOTE(review): `#[repr(C)]` presumably matters to code in `ref_count` that works with this
// type's layout — confirm before removing it.
#[repr(C)]
pub struct Ord<T: ?Sized + std::cmp::Ord> {
    /// The interned values, ordered by `InternedOrd<T>`'s `Ord` impl.
    set: RwLock<BTreeSet<InternedOrd<T>>>,
}
/// Loom-only drop hook for the shared interner state.
///
/// Acquires and immediately releases the set's read lock before the state is freed.
/// NOTE(review): presumably this lets loom observe a synchronization with the last writer
/// of the set before its contents are dropped — confirm the intent.
///
/// The impl targets `Ord<T>`, the interner state type in this file, and repeats its exact
/// bounds, as `Drop` impls must.
#[cfg(loom)]
impl<T: ?Sized + std::cmp::Ord> Drop for Ord<T> {
    fn drop(&mut self) {
        // Under loom, `read()` yields a `LockResult` wrapping the guard (see the `.unwrap()`
        // calls elsewhere). Dropping it explicitly releases the guard right away and avoids
        // an `unused_must_use` warning.
        drop(self.set.read());
    }
}
// SAFETY (review note): `Ord<T>` only holds `InternedOrd<T>` handles inside an `RwLock`.
// These impls assert that sharing and sending that state across threads is sound whenever
// `T: Send + Sync`, which is the same requirement `Arc<T>` imposes. Presumably `Interned`
// wraps a raw pointer, so the auto traits are not derived automatically — confirm against
// `ref_count` before changing the bounds.
unsafe impl<T: ?Sized + std::cmp::Ord + Sync + Send> Send for Ord<T> {}
unsafe impl<T: ?Sized + std::cmp::Ord + Sync + Send> Sync for Ord<T> {}
impl<T: ?Sized + std::cmp::Ord> Interner for Ord<T> {
    type T = T;
    /// `Interner` hook (trait defined in `ref_count`) that tries to evict `value` from the set.
    ///
    /// Under the write lock, the entry matching `value` is taken out of the set. A match is
    /// either an equal `Interned` handle or an equal `T`, per `InternedOrd`'s `cmp`.
    /// - If its `ref_count()` is exactly 1, it stays removed and is handed back as
    ///   `(true, Some(entry))`; the caller decides when that handle is dropped.
    /// - Otherwise it is reinserted and `(false, None)` is returned.
    /// - If no matching entry exists, `(true, None)` is returned.
    ///
    /// NOTE(review): the meaning of the returned `bool` and of "ref count == 1" is defined by
    /// the `Interner` contract in `ref_count`. Presumably 1 means only the set's own reference
    /// remains — confirm there.
    fn remove(&self, value: &Interned<Self>) -> (bool, Option<Interned<Self>>) {
        // View the raw handle as `InternedOrd` so the set's `Ord`-based lookup applies.
        let value = cast(value);
        let mut set = self.set.write();
        // loom's lock returns a `LockResult`; the non-loom lock returns the guard directly.
        #[cfg(loom)]
        let mut set = set.unwrap();
        if let Some(i) = set.take(value) {
            if i.ref_count() == 1 {
                (true, Some(i.0))
            } else {
                // Still referenced elsewhere: put the entry back untouched.
                set.insert(i);
                (false, None)
            }
        } else {
            (true, None)
        }
    }
}
impl<T: ?Sized + std::cmp::Ord> OrdInterner<T> {
pub fn new() -> Self {
Self {
inner: Arc::new(Ord {
set: RwLock::new(BTreeSet::new()),
}),
}
}
pub fn len(&self) -> usize {
let set = self.inner.set.read();
#[cfg(loom)]
let set = set.unwrap();
set.len()
}
pub fn is_empty(&self) -> bool {
let set = self.inner.set.read();
#[cfg(loom)]
let set = set.unwrap();
set.is_empty()
}
fn intern<U, F>(&self, value: U, intern: F) -> InternedOrd<T>
where
F: FnOnce(U) -> InternedOrd<T>,
U: Borrow<T>,
{
#[cfg(not(loom))]
let set = self.inner.set.upgradable_read();
#[cfg(loom)]
let set = self.inner.set.read().unwrap();
if let Some(entry) = set.get(value.borrow()) {
return entry.clone();
}
#[cfg(not(loom))]
let mut set = RwLockUpgradableReadGuard::upgrade(set);
#[cfg(loom)]
let mut set = {
drop(set);
self.inner.set.write().unwrap()
};
if let Some(entry) = set.get(value.borrow()) {
return entry.clone();
}
let mut ret = intern(value);
ret.0.make_hot(&self.inner);
set.insert(ret.clone());
ret
}
pub fn intern_ref(&self, value: &T) -> InternedOrd<T>
where
T: ToOwned,
T::Owned: Into<Box<T>>,
{
self.intern(value, |v| {
InternedOrd(Interned::from_box(v.to_owned().into()))
})
}
pub fn intern_box(&self, value: Box<T>) -> InternedOrd<T> {
self.intern(value, |v| InternedOrd(Interned::from_box(v)))
}
pub fn intern_sized(&self, value: T) -> InternedOrd<T>
where
T: Sized,
{
self.intern(value, |v| InternedOrd(Interned::from_sized(v)))
}
}
impl<T: ?Sized + std::cmp::Ord> Default for OrdInterner<T> {
fn default() -> Self {
Self::new()
}
}
/// A handle to a value interned in an [`OrdInterner`]; dereferences to the interned `T`.
///
/// Comparisons first check whether the two inner `Interned` handles are `==`. Only when
/// they are not are the underlying values compared.
// `#[repr(transparent)]` is load-bearing: `cast` reinterprets `&Interned<Ord<T>>` as
// `&InternedOrd<T>`, which is only sound because both types share one layout.
#[repr(transparent)]
pub struct InternedOrd<T: ?Sized + std::cmp::Ord>(Interned<Ord<T>>);
impl<T: ?Sized + std::cmp::Ord> InternedOrd<T> {
pub fn ref_count(&self) -> u32 {
self.0.ref_count()
}
}
impl<T: ?Sized + std::cmp::Ord> Clone for InternedOrd<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Reinterprets a raw `Interned` handle as the `InternedOrd` wrapper around it.
///
/// This lets the set, which is keyed by `InternedOrd<T>`, be searched with a bare handle.
fn cast<T: ?Sized + std::cmp::Ord>(i: &Interned<Ord<T>>) -> &InternedOrd<T> {
    let raw: *const Interned<Ord<T>> = i;
    // SAFETY: `InternedOrd<T>` is `#[repr(transparent)]` over `Interned<Ord<T>>`, so the two
    // types have identical size, alignment and layout. `raw` comes from a live reference,
    // and the returned reference borrows from `i`, so it cannot outlive the original.
    unsafe { &*raw.cast::<InternedOrd<T>>() }
}
impl<T: ?Sized + std::cmp::Ord> PartialEq for InternedOrd<T>
where
    T: PartialEq,
{
    /// Returns `true` immediately when both inner `Interned` handles are `==`. Otherwise
    /// compares the interned values with `T`'s `PartialEq`.
    fn eq(&self, other: &Self) -> bool {
        if self.0 == other.0 {
            return true;
        }
        **self == **other
    }
}
// Marker impl. `eq` either short-circuits on `==` inner handles or defers to `T`'s equality,
// so it is an equivalence whenever `T: Eq`. This assumes `Interned`'s own `==` is an
// equivalence relation (e.g. pointer identity) — confirm in `ref_count`.
impl<T: ?Sized + std::cmp::Ord> Eq for InternedOrd<T> where T: Eq {}
impl<T: ?Sized + std::cmp::Ord> PartialOrd for InternedOrd<T>
where
    T: PartialOrd,
{
    /// Returns the total order from `InternedOrd`'s `std::cmp::Ord` impl, wrapped in `Some`.
    ///
    /// The interner's `BTreeSet` orders entries with `cmp`. Deriving `partial_cmp` from it
    /// guarantees that `<`/`>` on handles can never disagree with the set's ordering.
    /// That order already includes the identical-handle shortcut.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Fully qualified because the local `Ord` struct shadows the trait's name here.
        Some(std::cmp::Ord::cmp(self, other))
    }
}
impl<T: ?Sized + std::cmp::Ord> std::cmp::Ord for InternedOrd<T> {
    /// Total order used by the interner's `BTreeSet`.
    ///
    /// Handles whose inner `Interned` values are `==` compare `Equal` without touching the
    /// data. Otherwise the interned values decide.
    fn cmp(&self, other: &Self) -> Ordering {
        if self.0 == other.0 {
            Ordering::Equal
        } else {
            T::cmp(&**self, &**other)
        }
    }
}
impl<T: ?Sized + std::cmp::Ord> std::hash::Hash for InternedOrd<T>
where
    T: std::hash::Hash,
{
    /// Hashes the interned value rather than the handle, so that equal values (which
    /// compare equal via `eq`) also hash equally.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let value: &T = self;
        T::hash(value, state)
    }
}
impl<T: ?Sized + std::cmp::Ord> Borrow<T> for InternedOrd<T> {
    /// Borrows the interned value. This is what lets the set be searched with a plain `&T`.
    fn borrow(&self) -> &T {
        &**self
    }
}
impl<T: ?Sized + std::cmp::Ord> Deref for InternedOrd<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
impl<T: ?Sized + std::cmp::Ord> AsRef<T> for InternedOrd<T> {
    /// Same as dereferencing: returns the interned value.
    fn as_ref(&self) -> &T {
        Deref::deref(self)
    }
}
impl<T: ?Sized + std::cmp::Ord> Debug for InternedOrd<T>
where
    T: Debug,
{
    /// Formats as `Interned(<value's Debug output>)`, e.g. `Interned("value")`.
    ///
    /// Uses `debug_tuple` so the caller's formatter options reach the inner value.
    /// `{:#?}` pretty-prints it, and flags such as `{:x?}` apply to it, matching how a
    /// derived `Debug` behaves.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pass `&&T` rather than `&T`: `T` may be unsized (e.g. `str`), and only a sized
        // reference can be coerced to `&dyn Debug`.
        f.debug_tuple("Interned").field(&&**self).finish()
    }
}
impl<T: ?Sized + std::cmp::Ord> Display for InternedOrd<T>
where
    T: Display,
{
    /// Displays exactly like the interned value, forwarding the caller's formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&**self, f)
    }
}
impl<T: ?Sized + std::cmp::Ord> Pointer for InternedOrd<T> {
    /// Formats the address of the interned value (e.g. via `{:p}`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let address: *const T = &**self;
        Pointer::fmt(&address, f)
    }
}
#[cfg(feature = "serde")]
impl<T: ?Sized + std::cmp::Ord + serde::Serialize> serde::Serialize for InternedOrd<T> {
    /// Serializes the interned value itself; the handle adds nothing to the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let value: &T = self;
        serde::Serialize::serialize(value, serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for InternedOrd<T>
where
    T: std::cmp::Ord + serde::Deserialize<'de> + Send + Sync + 'static,
{
    /// Deserializes a `T` and interns it via `crate::global::ord_interner()`.
    ///
    /// The interner is only consulted when deserialization succeeds.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        T::deserialize(deserializer).map(|value| crate::global::ord_interner().intern_sized(value))
    }
}
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for InternedOrd<[T]>
where
    T: std::cmp::Ord + serde::Deserialize<'de> + Send + Sync + 'static,
{
    /// Deserializes a sequence of `T`, boxes it as a slice, and interns it via
    /// `crate::global::ord_interner()`.
    ///
    /// The interner is only consulted when deserialization succeeds.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        Vec::<T>::deserialize(deserializer)
            .map(|items| crate::global::ord_interner().intern_box(items.into_boxed_slice()))
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for InternedOrd<str> {
    /// Deserializes a string, boxes it as `str`, and interns it via
    /// `crate::global::ord_interner()`.
    ///
    /// The interner is only consulted when deserialization succeeds.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        String::deserialize(deserializer)
            .map(|text| crate::global::ord_interner().intern_box(text.into_boxed_str()))
    }
}
#[test]
fn size() {
    // Guards against accidental growth of the shared per-interner state.
    let bytes = std::mem::size_of::<Ord<()>>();
    assert!(bytes < 100, "too big: {}", bytes);
}
#[test]
fn debug() {
    // An interned string must render through the `Interned(..)` wrapper in `{:?}` output.
    let interner: OrdInterner<str> = OrdInterner::new();
    let i = interner.intern_ref("value");
    let rendered = format!("{i:?}");
    assert_eq!(rendered, r#"Interned("value")"#);
}
/// Loom model-checking tests for races between dropping value handles, dropping interner
/// handles, and re-interning.
///
/// Each model ends by asserting that the shared `Ord` state was freed: the `Weak` taken at
/// the start must report a strong count of 0 and a weak count of 1.
#[cfg(all(test, loom))]
mod tests {
    use super::*;
    use ::loom::{model, thread::spawn};
    /// Reads the `(strong, weak)` counters of the allocation that `weak` points to.
    ///
    /// NOTE(review): this reinterprets the `Weak` as a pointer to two leading `usize`
    /// counters and reads them non-atomically.
    /// - That matches std's `ArcInner` layout (`strong`, `weak`, then data).
    /// - It is assumed, not checked, for whichever `Weak` is in scope here (presumably
    ///   re-exported via `crate::loom`). The `(usize, usize)` tuple layout is also
    ///   unspecified.
    /// - std's `Weak::weak_count` is not a drop-in replacement, because it reports 0 once
    ///   no strong reference remains.
    fn counts<T>(weak: Weak<T>) -> (usize, usize) {
        unsafe {
            let ptr = &weak as *const _ as *const *const (usize, usize);
            **ptr
        }
    }
    /// Dropping the only interner handle races with dropping the only interned value.
    #[test]
    fn drop_interner() {
        model(|| {
            let i = OrdInterner::new();
            let i2 = Arc::downgrade(&i.inner);
            let n = i.intern_box(42.into());
            let h = spawn(move || drop(i));
            let h2 = spawn(move || drop(n));
            h.join().unwrap();
            h2.join().unwrap();
            assert_eq!(counts(i2), (0, 1));
        })
    }
    /// Two clones of one interned value are dropped concurrently after the interner
    /// handle itself is already gone.
    #[test]
    fn drop_two_external() {
        model(|| {
            let i = OrdInterner::new();
            let i2 = Arc::downgrade(&i.inner);
            let n = i.intern_box(42.into());
            let n2 = n.clone();
            drop(i);
            let h = spawn(move || drop(n));
            let h2 = spawn(move || drop(n2));
            h.join().unwrap();
            h2.join().unwrap();
            assert_eq!(counts(i2), (0, 1));
        })
    }
    /// Dropping the only handle to a value races with interning an equal value again.
    #[test]
    fn drop_against_intern() {
        model(|| {
            let i = OrdInterner::new();
            let i2 = Arc::downgrade(&i.inner);
            let n = i.intern_box(42.into());
            let h1 = spawn(move || drop(n));
            let h2 = spawn(move || i.intern_box(42.into()));
            h1.join().unwrap();
            h2.join().unwrap();
            assert_eq!(counts(i2), (0, 1));
        })
    }
    /// Three threads race: one drops the value, one re-interns it, and one drops a second
    /// interner handle.
    // NOTE(review): "tree" in the test name presumably means "three".
    #[test]
    fn tree_drop_against_intern_and_interner() {
        model(|| {
            let i = OrdInterner::new();
            let i2 = Arc::downgrade(&i.inner);
            let n = i.intern_box(42.into());
            let ii = i.clone();
            // Tracing output showing which thread reaches each phase.
            println!("{:?} setup", current().id());
            let h1 = spawn(move || drop(n));
            let h2 = spawn(move || i.intern_box(42.into()));
            let h3 = spawn(move || drop(ii));
            println!("{:?} joining", current().id());
            h1.join().unwrap();
            h2.join().unwrap();
            h3.join().unwrap();
            assert_eq!(counts(i2), (0, 1));
            println!("{:?} done", current().id());
        })
    }
}