#![cfg(feature = "localmap")]
use std::{
any::TypeId,
collections::{hash_map, HashMap},
fmt,
hash::{BuildHasherDefault, Hasher},
marker::PhantomData,
};
pub use crate::local_key;
#[macro_export]
macro_rules! local_key {
($(#[$m:meta])* $vis:vis static $NAME:ident : $t:ty; $($tail:tt)*) => {
local_key!(@declare $(#[$m])* ($vis) static $NAME: $t);
local_key!($($tail)*);
};
($(#[$m:meta])* $vis:vis const $NAME:ident : $t:ty; $($tail:tt)*) => {
local_key!(@declare $(#[$m])* ($vis) const $NAME: $t);
local_key!($($tail)*);
};
() => ();
(@declare $(#[$m:meta])* ($($vis:tt)*) $kw:tt $NAME:ident : $t:ty) => {
$(#[$m])*
$($vis)* $kw $NAME: $crate::localmap::LocalKey<$t> = {
fn __type_id() -> std::any::TypeId {
struct __A;
std::any::TypeId::of::<__A>()
}
$crate::localmap::LocalKey {
__type_id,
__marker: std::marker::PhantomData,
}
};
};
}
#[derive(Debug)]
pub struct LocalKey<T: Send + 'static> {
#[doc(hidden)]
pub __type_id: fn() -> TypeId,
#[doc(hidden)]
pub __marker: PhantomData<fn() -> T>,
}
impl<T: Send + 'static> LocalKey<T> {
#[inline]
fn type_id(&'static self) -> TypeId {
(self.__type_id)()
}
}
struct IdentHasher(u64);
impl Default for IdentHasher {
fn default() -> Self {
IdentHasher(0)
}
}
impl Hasher for IdentHasher {
fn finish(&self) -> u64 {
self.0
}
fn write(&mut self, bytes: &[u8]) {
for &b in bytes {
self.write_u8(b);
}
}
fn write_u8(&mut self, i: u8) {
self.0 = (self.0 << 8) | u64::from(i);
}
fn write_u64(&mut self, i: u64) {
self.0 = i;
}
}
trait Opaque: Send + 'static {}
impl<T: Send + 'static> Opaque for T {}
impl dyn Opaque {
unsafe fn downcast_ref_unchecked<T: Send + 'static>(&self) -> &T {
&*(self as *const dyn Opaque as *const T)
}
unsafe fn downcast_mut_unchecked<T: Send + 'static>(&mut self) -> &mut T {
&mut *(self as *mut dyn Opaque as *mut T)
}
}
trait BoxDowncastExt {
unsafe fn downcast_unchecked<T: Send + 'static>(self) -> Box<T>;
}
#[cfg_attr(feature = "cargo-clippy", allow(use_self))]
impl BoxDowncastExt for Box<dyn Opaque> {
unsafe fn downcast_unchecked<T: Send + 'static>(self) -> Box<T> {
Box::from_raw(Box::into_raw(self) as *mut T)
}
}
#[derive(Default)]
pub struct LocalMap {
inner: HashMap<TypeId, Box<dyn Opaque>, BuildHasherDefault<IdentHasher>>,
}
#[cfg_attr(tarpaulin, skip)]
impl fmt::Debug for LocalMap {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LocalMap").finish()
}
}
impl LocalMap {
pub fn get<T>(&self, key: &'static LocalKey<T>) -> Option<&T>
where
T: Send + 'static,
{
Some(unsafe { self.inner.get(&key.type_id())?.downcast_ref_unchecked() })
}
pub fn get_mut<T>(&mut self, key: &'static LocalKey<T>) -> Option<&mut T>
where
T: Send + 'static,
{
Some(unsafe { self.inner.get_mut(&key.type_id())?.downcast_mut_unchecked() })
}
pub fn contains_key<T>(&self, key: &'static LocalKey<T>) -> bool
where
T: Send + 'static,
{
self.inner.contains_key(&key.type_id())
}
pub fn insert<T>(&mut self, key: &'static LocalKey<T>, value: T) -> Option<T>
where
T: Send + 'static,
{
Some(unsafe {
*self
.inner
.insert(key.type_id(), Box::new(value))?
.downcast_unchecked()
})
}
pub fn remove<T>(&mut self, key: &'static LocalKey<T>) -> Option<T>
where
T: Send + 'static,
{
Some(unsafe { *self.inner.remove(&key.type_id())?.downcast_unchecked() })
}
pub fn entry<T>(&mut self, key: &'static LocalKey<T>) -> Entry<'_, T>
where
T: Send + 'static,
{
match self.inner.entry(key.type_id()) {
hash_map::Entry::Occupied(entry) => Entry::Occupied(OccupiedEntry {
inner: entry,
#[cfg_attr(tarpaulin, skip)]
_marker: PhantomData,
}),
hash_map::Entry::Vacant(entry) => Entry::Vacant(VacantEntry {
inner: entry,
#[cfg_attr(tarpaulin, skip)]
_marker: PhantomData,
}),
}
}
}
#[derive(Debug)]
pub enum Entry<'a, T: Send + 'static> {
Occupied(OccupiedEntry<'a, T>),
Vacant(VacantEntry<'a, T>),
}
impl<'a, T> Entry<'a, T>
where
T: Send + 'static,
{
#[allow(missing_docs)]
pub fn or_insert(self, default: T) -> &'a mut T {
self.or_insert_with(|| default)
}
#[allow(missing_docs)]
pub fn or_insert_with<F>(self, default: F) -> &'a mut T
where
F: FnOnce() -> T,
{
match self {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(default()),
}
}
#[allow(missing_docs)]
pub fn and_modify<F>(self, f: F) -> Entry<'a, T>
where
F: FnOnce(&mut T),
{
match self {
Entry::Occupied(mut entry) => {
f(entry.get_mut());
Entry::Occupied(entry)
}
Entry::Vacant(entry) => Entry::Vacant(entry),
}
}
}
pub struct OccupiedEntry<'a, T: Send + 'static> {
inner: hash_map::OccupiedEntry<'a, TypeId, Box<dyn Opaque>>,
_marker: PhantomData<T>,
}
#[cfg_attr(tarpaulin, skip)]
impl<'a, T> fmt::Debug for OccupiedEntry<'a, T>
where
T: Send + 'static + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("OccupiedEntry").field(self.get()).finish()
}
}
#[allow(missing_docs)]
impl<'a, T> OccupiedEntry<'a, T>
where
T: Send + 'static,
{
pub fn get(&self) -> &T {
unsafe { self.inner.get().downcast_ref_unchecked() }
}
pub fn get_mut(&mut self) -> &mut T {
unsafe { self.inner.get_mut().downcast_mut_unchecked() }
}
pub fn into_mut(self) -> &'a mut T {
unsafe { self.inner.into_mut().downcast_mut_unchecked() }
}
pub fn insert(&mut self, value: T) -> T {
unsafe { *self.inner.insert(Box::new(value)).downcast_unchecked() }
}
pub fn remove(self) -> T {
unsafe { *self.inner.remove().downcast_unchecked() }
}
}
pub struct VacantEntry<'a, T: Send + 'static> {
inner: hash_map::VacantEntry<'a, TypeId, Box<dyn Opaque>>,
_marker: PhantomData<T>,
}
#[cfg_attr(tarpaulin, skip)]
impl<'a, T: Send + 'static> fmt::Debug for VacantEntry<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VacantEntry").finish()
}
}
#[allow(missing_docs)]
impl<'a, T> VacantEntry<'a, T>
where
T: Send + 'static,
{
pub fn insert(self, default: T) -> &'a mut T {
unsafe {
self.inner
.insert(Box::new(default))
.downcast_mut_unchecked()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn smoke_test() {
let mut map = LocalMap::default();
local_key! {
static KEY: String;
}
assert!(!map.contains_key(&KEY));
assert!(map.insert(&KEY, String::from("foo")).is_none());
assert!(map.contains_key(&KEY));
assert_eq!(map.get(&KEY).map(String::as_str), Some("foo"));
assert_eq!(map.insert(&KEY, String::from("bar")), Some("foo".into()));
assert!(map.contains_key(&KEY));
assert_eq!(map.get(&KEY).map(String::as_str), Some("bar"));
assert_eq!(map.remove(&KEY), Some("bar".into()));
assert!(!map.contains_key(&KEY));
}
#[test]
fn entry_or_insert() {
let mut map = LocalMap::default();
local_key! {
static KEY: String;
}
map.entry(&KEY).or_insert("foo".into());
assert_eq!(map.get(&KEY).map(String::as_str), Some("foo"));
map.entry(&KEY).or_insert("bar".into());
assert_eq!(map.get(&KEY).map(String::as_str), Some("foo"));
}
#[test]
fn entry_and_modify() {
let mut map = LocalMap::default();
local_key! {
static KEY: String;
}
map.entry(&KEY).and_modify(|s| {
*s += "foo";
});
assert!(!map.contains_key(&KEY));
map.insert(&KEY, "foo".into());
map.entry(&KEY).and_modify(|s| {
*s += "bar";
});
assert_eq!(map.get(&KEY).map(String::as_str), Some("foobar"));
map.entry(&KEY).and_modify(|s| {
*s += "baz";
});
assert_eq!(map.get(&KEY).map(String::as_str), Some("foobarbaz"));
}
#[test]
fn occupied_entry() {
let mut map = LocalMap::default();
local_key! {
static KEY: String;
}
map.insert(&KEY, "foo".into());
if let Entry::Occupied(mut entry) = map.entry(&KEY) {
assert_eq!(entry.get(), "foo");
assert_eq!(entry.insert("bar".into()), "foo");
assert_eq!(entry.get(), "bar");
assert_eq!(entry.remove(), "bar");
}
assert!(!map.contains_key(&KEY));
}
#[test]
fn local_key_const() {
local_key! {
const KEY: String;
}
let mut map = LocalMap::default();
map.insert(&KEY, "foo".into());
assert!(map.contains_key(&KEY));
}
}