#![warn(unsafe_op_in_unsafe_fn)]
use once_cell::sync::Lazy;
use std::any::Any;
use std::collections::HashSet;
use std::mem::MaybeUninit;
use std::ops::ControlFlow;
use std::ops::Deref;
use std::ptr;
use std::sync::atomic::AtomicPtr;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use std::sync::Weak;
/// An intrusive doubly linked list of reference-counted, type-erased nodes.
///
/// The list owns one strong `Arc` reference to every node it contains; nodes
/// are chained through the `prev`/`next` links stored in each node's
/// `RawArcNode` header. `head` is the front (where `push_front` inserts) and
/// `tail` is the back (where `pop_back` removes).
pub struct WeakList<S: WeakListNodeSet> {
    /// First node, or null when the list is empty.
    head: *const RawArcNode,
    /// Last node, or null when the list is empty.
    tail: *const RawArcNode,
    /// Index used to answer "is this node linked into this list?" queries.
    node_set: S,
}
/// Membership index used by [`WeakList`] to decide whether a node pointer
/// belongs to a given list.
///
/// `insert_ns` is called when a node is linked into the list and `remove_ns`
/// when it is unlinked; `contains_ns` is consulted before the list unlinks a
/// node on behalf of a `WeakRef` (`WeakList::remove`, `WeakRef::upgrade`).
///
/// # Safety
///
/// `WeakList` trusts `contains_ns`: when it returns `true`, the list treats
/// `ptr` as one of its own nodes, unlinks it and may reclaim the strong
/// reference it holds. Implementations must therefore return `true` only if
/// `ptr` is the address of a node currently linked into `list`. A false
/// positive leads to undefined behavior.
pub unsafe trait WeakListNodeSet {
    /// Records that the node at address `ptr` was linked into the list.
    fn insert_ns(&mut self, ptr: usize);
    /// Returns whether the node at address `ptr` is currently linked into `list`.
    fn contains_ns(&self, list: &WeakList<Self>, ptr: usize) -> bool
    where
        Self: Sized;
    /// Records that the node at address `ptr` was unlinked from the list.
    fn remove_ns(&mut self, ptr: usize);
}
/// Hash-set backed node index. It is wrapped in `Lazy` so that
/// `WeakList::new` can be a `const fn`.
pub type WeakListHashSet = Lazy<HashSet<usize>>;

// SAFETY: an address is inserted when its node is linked into the list and
// removed when it is unlinked, so `contains_ns` reports exactly the nodes
// currently owned by the list.
unsafe impl WeakListNodeSet for WeakListHashSet {
    fn insert_ns(&mut self, ptr: usize) {
        let set: &mut HashSet<usize> = self;
        set.insert(ptr);
    }

    fn contains_ns(&self, _list: &WeakList<Self>, ptr: usize) -> bool {
        let set: &HashSet<usize> = self;
        set.contains(&ptr)
    }

    fn remove_ns(&mut self, ptr: usize) {
        let set: &mut HashSet<usize> = self;
        set.remove(&ptr);
    }
}
/// Node index that stores nothing and answers membership queries by walking
/// the list from `head`. Each lookup is O(n), but no extra memory is used.
#[derive(Default)]
pub struct LinearSearch;

// SAFETY: `contains_ns` only reports addresses reachable from `list.head`,
// i.e. nodes currently linked into `list`.
unsafe impl WeakListNodeSet for LinearSearch {
    fn insert_ns(&mut self, _ptr: usize) {
        // Nothing to record: membership is derived from the links themselves.
    }

    fn contains_ns(&self, list: &WeakList<Self>, ptr: usize) -> bool {
        let mut cursor = list.head;
        loop {
            if cursor.is_null() {
                return false;
            }
            if cursor as usize == ptr {
                return true;
            }
            // SAFETY: `cursor` is non-null and reachable from `list.head`, so it
            // points to a live node owned by `list`.
            cursor = unsafe { read_raw_node_next(cursor) };
        }
    }

    fn remove_ns(&mut self, _ptr: usize) {
        // Nothing to forget, for the same reason as `insert_ns`.
    }
}
// SAFETY: `head`/`tail` only point to nodes the list holds a strong `Arc`
// to, and `push_front`/`push_front_no_alloc` only accept
// `T: Send + Sync + 'static` elements. Moving or sharing the list across
// threads is therefore as safe as moving or sharing those `Arc`s, provided
// the node set `S` is itself `Send`/`Sync`.
unsafe impl<S: Send + WeakListNodeSet> Send for WeakList<S> {}
unsafe impl<S: Sync + WeakListNodeSet> Sync for WeakList<S> {}
impl<S: Default + WeakListNodeSet> Default for WeakList<S> {
fn default() -> Self {
Self {
head: ptr::null(),
tail: ptr::null(),
node_set: S::default(),
}
}
}
impl<S: WeakListNodeSet> Drop for WeakList<S> {
fn drop(&mut self) {
self.clear();
}
}
impl WeakList<Lazy<HashSet<usize>>> {
pub const fn new() -> Self {
Self {
head: ptr::null(),
tail: ptr::null(),
node_set: Lazy::new(HashSet::new),
}
}
pub fn realloc_hashset_if_needed_no_alloc(
&mut self,
bigger_hashset: Option<&mut AllocHashSet>,
) {
if self.node_set.len() != self.node_set.capacity() {
return;
}
match bigger_hashset {
Some(new_hs) => {
if new_hs.capacity() <= self.node_set.capacity() {
panic!("New AllocHashSet capacity must be greater than the current capacity but {} <= {}", new_hs.capacity(), self.node_set.capacity());
}
new_hs.0.extend(self.node_set.drain());
std::mem::swap(&mut *self.node_set, &mut new_hs.0);
}
None => {
}
}
}
pub fn hashset_capacity(&self) -> usize {
self.node_set.capacity()
}
}
impl<S: WeakListNodeSet> WeakList<S> {
    /// Pushes `elem` to the front of the list and returns a weak handle to it.
    ///
    /// Allocates the node; use [`WeakList::push_front_no_alloc`] to supply the
    /// memory up front instead.
    pub fn push_front<T: Send + Sync + 'static>(&mut self, elem: T) -> WeakRef<T> {
        self.push_front_no_alloc(elem, AllocMem::default())
    }
    /// Pushes `elem` to the front of the list using the pre-allocated `memory`
    /// and returns a weak handle to it.
    ///
    /// After this call the list holds the only strong reference to the new
    /// node; the returned `WeakRef` does not keep it alive.
    pub fn push_front_no_alloc<T: Send + Sync + 'static>(
        &mut self,
        elem: T,
        memory: AllocMem<T>,
    ) -> WeakRef<T> {
        // Monomorphised for `T`, so the type-erased node can later be turned
        // back into an `Arc` of the right type.
        let f_move_out = arc_from_raw_to_arc_any::<T>;
        let meta = RawArcNode {
            prev: AtomicPtr::new(ptr::null_mut()),
            next: AtomicPtr::new(ptr::null_mut()),
            f_move_out,
        };
        let mut uninit_arc: Arc<MaybeUninit<Node<T>>> = memory.0;
        let node = Node { meta, elem };
        // `AllocMem` is not `Clone` and does not expose its `Arc`, so the `Arc`
        // is unique here and `get_mut` cannot fail.
        Arc::get_mut(&mut uninit_arc).unwrap().write(node);
        // SAFETY: the value was initialized just above, and
        // `MaybeUninit<Node<T>>` has the same layout as `Node<T>`.
        let arc_node: Arc<Node<T>> =
            unsafe { Arc::from_raw(Arc::into_raw(uninit_arc) as *const Node<T>) };
        let weak_ref = WeakRef {
            weak: Arc::downgrade(&arc_node),
        };
        // `Node` is `#[repr(C)]` with `meta` as its first field, so the node
        // pointer doubles as a `RawArcNode` pointer. From here on the list owns
        // this strong reference.
        let raw_node_ptr = Arc::into_raw(arc_node) as *const RawArcNode;
        // SAFETY: fresh node with null links that is not in any list yet.
        unsafe { self.push_front_node(raw_node_ptr) }
        weak_ref
    }
    /// Links `raw_node_ptr` in as the new head and records it in the node set.
    ///
    /// # Safety
    ///
    /// `raw_node_ptr` must point to a live node whose `prev`/`next` links are
    /// null and which is not in any list; its strong reference is transferred
    /// to this list.
    unsafe fn push_front_node(&mut self, raw_node_ptr: *const RawArcNode) {
        match (self.head.is_null(), self.tail.is_null()) {
            // Empty list: the new node is both head and tail.
            (true, true) => {
                self.head = raw_node_ptr;
                self.tail = self.head;
            }
            // Non-empty list: link the new node in before the old head.
            (false, false) => {
                unsafe {
                    raw_nodes_link(raw_node_ptr, self.head);
                }
                self.head = raw_node_ptr;
            }
            _ => unreachable!("head and tail must both be null or both be not null"),
        }
        self.node_set.insert_ns(raw_node_ptr as usize);
    }
    /// Returns whether `raw_node_ptr` is a node currently linked into this list.
    fn contains_node(&self, raw_node_ptr: *const RawArcNode) -> bool {
        self.node_set.contains_ns(self, raw_node_ptr as usize)
    }
    /// Unlinks `raw_node_ptr` from the list, fixes up `head`/`tail` and forgets
    /// the node in the node set.
    ///
    /// The list's strong reference is NOT released; the caller takes it over.
    ///
    /// # Safety
    ///
    /// `raw_node_ptr` must be a node currently linked into this list.
    unsafe fn remove_node(&mut self, raw_node_ptr: *const RawArcNode) {
        unsafe {
            // Read the neighbours before `remove_node_from_list` clears the
            // node's own links.
            if self.head == raw_node_ptr {
                self.head = read_raw_node_next(raw_node_ptr);
            }
            if self.tail == raw_node_ptr {
                self.tail = read_raw_node_prev(raw_node_ptr);
            }
            remove_node_from_list(raw_node_ptr);
            self.node_set.remove_ns(raw_node_ptr as usize);
        }
    }
    /// Moves a node that is already in this list to the head position.
    ///
    /// The node stays in the node set, since it never leaves the list.
    ///
    /// # Safety
    ///
    /// `raw_node_ptr` must be a node currently linked into this list.
    unsafe fn move_node_to_front(&mut self, raw_node_ptr: *const RawArcNode) {
        if self.head == raw_node_ptr {
            return;
        }
        // The node is not the head, so the list has at least two nodes and
        // `self.head` is still non-null after the node is unlinked.
        unsafe {
            if self.tail == raw_node_ptr {
                self.tail = read_raw_node_prev(raw_node_ptr);
            }
            remove_node_from_list(raw_node_ptr);
        }
        unsafe {
            raw_nodes_link(raw_node_ptr, self.head);
        }
        self.head = raw_node_ptr;
    }
    /// Removes the last node and returns the list's strong reference to it as
    /// a type-erased `Arc`, or `None` if the list is empty.
    pub fn pop_back(&mut self) -> Option<Arc<dyn Any + Send + Sync + 'static>> {
        if self.tail.is_null() {
            return None;
        }
        unsafe {
            let tail = self.tail;
            self.remove_node(tail);
            // The list's strong reference is handed to the caller here.
            let arc_any = move_out_of_raw_node(tail);
            Some(arc_any)
        }
    }
    /// Evicts the node closest to the tail that has no strong reference other
    /// than the list's own (for example, no live `ArcRef`). Weak references do
    /// not prevent eviction.
    ///
    /// The nodes between that node and the tail are still in use, so the list
    /// is first rotated to move them to the front. Returns `None` if every
    /// node currently has another strong reference.
    pub fn pop_lru(&mut self) -> Option<Arc<dyn Any + Send + Sync + 'static>> {
        let mut node = self.tail;
        while !node.is_null() {
            unsafe {
                let strong_count = read_raw_node_strong_count(node);
                if strong_count == 1 {
                    let n2 = node;
                    let n3 = read_raw_node_next(n2);
                    if !n3.is_null() {
                        // Rotate so the in-use nodes after `n2` move in front
                        // of the old head and `n2` becomes the tail:
                        // [n0 .. n2][n3 .. n4] -> [n3 .. n4][n0 .. n2]
                        raw_nodes_cut(n2, n3);
                        let n4 = self.tail;
                        let n0 = self.head;
                        raw_nodes_link(n4, n0);
                        self.head = n3;
                        self.tail = n2;
                    }
                    return self.pop_back();
                }
                node = read_raw_node_prev(node);
            }
        }
        None
    }
    /// Removes every node that has no `WeakRef`s and no strong references
    /// besides the list's own, and returns the list's strong references to them.
    pub fn remove_unreachable(&mut self) -> Vec<Arc<dyn Any + Send + Sync + 'static>> {
        let mut v = vec![];
        self.remove_unreachable_into_f(|arc_any| {
            v.push(arc_any);
            ControlFlow::Continue(())
        });
        v
    }
    /// Like [`WeakList::remove_unreachable`], but writes into `buf` instead of
    /// allocating and stops once `buf` is full.
    ///
    /// Returns the number of slots filled, starting at index 0. Values already
    /// present in those slots are overwritten (and dropped).
    pub fn remove_unreachable_into_buf(
        &mut self,
        buf: &mut [Option<Arc<dyn Any + Send + Sync + 'static>>],
    ) -> usize {
        if buf.is_empty() {
            return 0;
        }
        let mut count = 0;
        self.remove_unreachable_into_f(|arc_any| {
            buf[count] = Some(arc_any);
            count += 1;
            if count == buf.len() {
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            }
        });
        count
    }
    /// Walks from tail to head, removing each unreachable node (see
    /// [`WeakList::remove_unreachable`]) and passing its strong reference to
    /// `f`. Stops early when `f` returns `ControlFlow::Break`.
    pub fn remove_unreachable_into_f<F>(&mut self, mut f: F)
    where
        F: FnMut(Arc<dyn Any + Send + Sync + 'static>) -> ControlFlow<()>,
    {
        let mut node = self.tail;
        while !node.is_null() {
            unsafe {
                // Walking towards the head: grab the link before `node` is
                // unlinked, which clears its `prev`.
                let next_node = read_raw_node_prev(node);
                if raw_node_arc_is_unique(node) {
                    self.remove_node(node);
                    let arc_any = move_out_of_raw_node(node);
                    if f(arc_any).is_break() {
                        break;
                    }
                }
                node = next_node;
            }
        }
    }
    /// Removes the node referenced by `weak_ref` and returns the list's strong
    /// reference to it.
    ///
    /// Returns `None` if the node is not linked into this list, for example
    /// because it was already removed or belongs to another list.
    pub fn remove<T>(&mut self, weak_ref: &WeakRef<T>) -> Option<ArcRef<T>> {
        // The `WeakRef` keeps the allocation alive, so this address cannot have
        // been reused by a different node.
        let raw_node_ptr = weak_ref.weak.as_ptr() as *const RawArcNode;
        if !self.contains_node(raw_node_ptr) {
            return None;
        }
        // SAFETY: `contains_node` confirmed the node is linked into this list.
        unsafe { self.remove_node(raw_node_ptr) };
        // SAFETY: reclaims the strong reference the list held; the node was
        // created as a `Node<T>` for this `T`.
        let arc: Arc<Node<T>> = unsafe { Arc::from_raw(raw_node_ptr as *const Node<T>) };
        Some(ArcRef { arc })
    }
    /// Removes every node, dropping the list's strong reference to each.
    pub fn clear(&mut self) {
        while let Some(arc) = self.pop_back() {
            drop(arc);
        }
    }
}
pub struct AllocMem<T>(Arc<MaybeUninit<Node<T>>>);
impl<T> Default for AllocMem<T> {
fn default() -> Self {
Self(Arc::new(MaybeUninit::uninit()))
}
}
/// A pre-allocated hash set that can be swapped into a `WeakList` with
/// `realloc_hashset_if_needed_no_alloc`, so that growing the list's node set
/// does not allocate at that point.
///
/// The wrapped `HashSet<usize>` is already `Send + Sync`, so this type is
/// too; no `unsafe impl` is needed.
pub struct AllocHashSet(HashSet<usize>);

impl Default for AllocHashSet {
    /// Same as [`AllocHashSet::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl AllocHashSet {
    /// Creates an empty set without allocating.
    pub fn new() -> Self {
        Self(HashSet::new())
    }
    /// Creates an empty set that can hold at least `capacity` entries
    /// without reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self(HashSet::with_capacity(capacity))
    }
    /// Number of entries the set can hold without reallocating.
    pub fn capacity(&self) -> usize {
        self.0.capacity()
    }
    /// Grows the set so it can hold at least `target_cap` entries in total
    /// (not `target_cap` additional ones). Never shrinks.
    pub fn allocate_capacity(&mut self, target_cap: usize) {
        // `reserve` takes the number of *additional* entries beyond `len`.
        self.0.reserve(target_cap.saturating_sub(self.0.len()))
    }
}
/// Rebuilds the strong `Arc<Node<T>>` from a type-erased node pointer and
/// erases it to `Arc<dyn Any + Send + Sync>`.
///
/// A monomorphised instance is stored in `RawArcNode::f_move_out`, so a node
/// can be recovered without knowing `T`.
///
/// # Safety
///
/// `raw_node_ptr` must come from `Arc::<Node<T>>::into_raw` for this exact
/// `T`. The returned `Arc` accounts for one strong reference: the caller must
/// either own that reference or `forget` the result.
unsafe fn arc_from_raw_to_arc_any<T: Send + Sync + 'static>(
    raw_node_ptr: *const RawArcNode,
) -> Arc<dyn Any + Send + Sync + 'static> {
    let typed_ptr = raw_node_ptr as *const Node<T>;
    // SAFETY: guaranteed by the caller, see `# Safety`.
    let typed_arc: Arc<Node<T>> = unsafe { Arc::from_raw(typed_ptr) };
    typed_arc
}
/// Loads the `prev` link (towards `head`) of the node header at `p`.
///
/// # Safety
///
/// `p` must point to a live `RawArcNode`.
unsafe fn read_raw_node_prev(p: *const RawArcNode) -> *const RawArcNode {
    // SAFETY: `p` is a live node header per the caller's contract.
    unsafe { (*p).prev.load(SeqCst) }
}

/// Loads the `next` link (towards `tail`) of the node header at `p`.
///
/// # Safety
///
/// `p` must point to a live `RawArcNode`.
unsafe fn read_raw_node_next(p: *const RawArcNode) -> *const RawArcNode {
    // SAFETY: `p` is a live node header per the caller's contract.
    unsafe { (*p).next.load(SeqCst) }
}

/// Stores `prev` as the `prev` link of the node header at `p`.
///
/// # Safety
///
/// `p` must point to a live `RawArcNode`.
unsafe fn update_raw_node_prev(p: *const RawArcNode, prev: *const RawArcNode) {
    // SAFETY: `p` is a live node header per the caller's contract.
    unsafe { (*p).prev.store(prev as *mut RawArcNode, SeqCst) }
}

/// Stores `next` as the `next` link of the node header at `p`.
///
/// # Safety
///
/// `p` must point to a live `RawArcNode`.
unsafe fn update_raw_node_next(p: *const RawArcNode, next: *const RawArcNode) {
    // SAFETY: `p` is a live node header per the caller's contract.
    unsafe { (*p).next.store(next as *mut RawArcNode, SeqCst) }
}
/// Breaks the `p0 -> p1` link in both directions: `p0.next` and `p1.prev`
/// become null.
///
/// # Safety
///
/// Both pointers must refer to live node headers.
unsafe fn raw_nodes_cut(p0: *const RawArcNode, p1: *const RawArcNode) {
    let none = ptr::null();
    // SAFETY: `p0` is a live node header per the caller's contract.
    unsafe { update_raw_node_next(p0, none) };
    // SAFETY: `p1` is a live node header per the caller's contract.
    unsafe { update_raw_node_prev(p1, none) };
}

/// Makes `p1` follow `p0`: sets `p0.next = p1` and `p1.prev = p0`.
///
/// The other links of `p0` and `p1` are left untouched.
///
/// # Safety
///
/// Both pointers must refer to live node headers.
unsafe fn raw_nodes_link(p0: *const RawArcNode, p1: *const RawArcNode) {
    // SAFETY: `p0` is a live node header per the caller's contract.
    unsafe { update_raw_node_next(p0, p1) };
    // SAFETY: `p1` is a live node header per the caller's contract.
    unsafe { update_raw_node_prev(p1, p0) };
}
/// Splices the node at `p` out of its chain.
///
/// Its non-null neighbours are linked to each other, and `p`'s own
/// `prev`/`next` are reset to null. A list's `head`/`tail` are not touched.
/// Callers that need `p`'s old neighbours (e.g. to update `head`/`tail`)
/// must read them before calling.
///
/// # Safety
///
/// `p` and its non-null neighbours must be live node headers.
unsafe fn remove_node_from_list(p: *const RawArcNode) {
    // SAFETY: `p` is a live node header per the caller's contract.
    let (before, after) = unsafe { (read_raw_node_prev(p), read_raw_node_next(p)) };
    if !before.is_null() {
        // SAFETY: non-null neighbours are live per the caller's contract.
        unsafe { update_raw_node_next(before, after) };
    }
    if !after.is_null() {
        // SAFETY: non-null neighbours are live per the caller's contract.
        unsafe { update_raw_node_prev(after, before) };
    }
    // SAFETY: `p` is a live node header per the caller's contract.
    unsafe {
        update_raw_node_prev(p, ptr::null());
        update_raw_node_next(p, ptr::null());
    }
}
unsafe fn raw_node_arc_is_unique(p: *const RawArcNode) -> bool {
let mut dummy_arc: Arc<Node<()>> = unsafe { Arc::from_raw(p as *const Node<()>) };
let is_unique = Arc::get_mut(&mut dummy_arc).is_some();
std::mem::forget(dummy_arc);
is_unique
}
unsafe fn read_raw_node_strong_count(p: *const RawArcNode) -> usize {
let dummy_arc: Arc<Node<()>> = unsafe { Arc::from_raw(p as *const Node<()>) };
let strong_count = Arc::strong_count(&dummy_arc);
std::mem::forget(dummy_arc);
strong_count
}
/// Converts the node at `node` back into the type-erased strong `Arc` the
/// list held, using the node's stored `f_move_out` (which knows its `T`).
///
/// # Safety
///
/// `node` must be a node created by `push_front_no_alloc`. The caller takes
/// over the strong reference it represents: it must either own that
/// reference or `forget` the result.
unsafe fn move_out_of_raw_node(node: *const RawArcNode) -> Arc<dyn Any + Send + Sync + 'static> {
    // SAFETY: `node` is a live node header, and its `f_move_out` was
    // instantiated for this node's concrete element type.
    unsafe {
        let reconstruct = (*node).f_move_out;
        reconstruct(node)
    }
}
/// Type-erased header shared by every node, whatever its element type.
///
/// `#[repr(C)]`, together with `Node` placing it as the first field, lets a
/// `*const Node<T>` be cast to `*const RawArcNode` and back.
#[repr(C)]
struct RawArcNode {
    /// Previous node (towards `head`), or null.
    prev: AtomicPtr<RawArcNode>,
    /// Next node (towards `tail`), or null.
    next: AtomicPtr<RawArcNode>,
    /// Rebuilds the node's strong `Arc` with its concrete element type. It is
    /// always `arc_from_raw_to_arc_any::<T>` for the node's `T`.
    f_move_out: unsafe fn(*const RawArcNode) -> Arc<dyn Any + Send + Sync + 'static>,
}
/// A list node: the type-erased header followed by the element.
#[repr(C)]
struct Node<T: ?Sized> {
    // Must stay the first field so node pointers can be cast to header pointers.
    meta: RawArcNode,
    elem: T,
}
/// Weak handle to an element stored in a [`WeakList`].
///
/// It does not keep the element alive; use `upgrade` or `upgrade_quietly`
/// to obtain an [`ArcRef`].
pub struct WeakRef<T: ?Sized> {
    weak: Weak<Node<T>>,
}

impl<T: ?Sized> Clone for WeakRef<T> {
    fn clone(&self) -> Self {
        let weak = Weak::clone(&self.weak);
        WeakRef { weak }
    }
}

impl<T: ?Sized + Send + Sync + 'static> WeakRef<T> {
    /// Upgrades to a strong reference and moves the element's node to the
    /// front of `list`.
    ///
    /// Returns `None` if the element has been dropped, or if its node is not
    /// linked into `list` (e.g. it was removed, or belongs to another list).
    pub fn upgrade<S: WeakListNodeSet>(&self, list: &mut WeakList<S>) -> Option<ArcRef<T>> {
        let strong = self.upgrade_quietly()?;
        let raw_node_ptr = self.weak.as_ptr() as *const RawArcNode;
        if list.contains_node(raw_node_ptr) {
            // SAFETY: `contains_node` confirmed the node is linked into `list`.
            unsafe { list.move_node_to_front(raw_node_ptr) };
            Some(strong)
        } else {
            None
        }
    }

    /// Upgrades to a strong reference without touching any list, so the
    /// node keeps its current position.
    pub fn upgrade_quietly(&self) -> Option<ArcRef<T>> {
        let arc = self.weak.upgrade()?;
        Some(ArcRef { arc })
    }
}
/// Strong reference to an element stored in a [`WeakList`]; dereferences to `T`.
///
/// While an `ArcRef` is alive, `pop_lru` and `remove_unreachable` skip the
/// node. `pop_back`, `remove` and `clear` still unlink it.
pub struct ArcRef<T: ?Sized> {
    arc: Arc<Node<T>>,
}

impl<T: ?Sized> Clone for ArcRef<T> {
    fn clone(&self) -> Self {
        ArcRef {
            arc: Arc::clone(&self.arc),
        }
    }
}

impl<T: ?Sized> Deref for ArcRef<T> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &T {
        let node: &Node<T> = &self.arc;
        &node.elem
    }
}

impl<T: ?Sized> ArcRef<T> {
    /// Returns a mutable reference to the element if `this` is the only
    /// reference (strong or weak) to the node, as with `Arc::get_mut`.
    pub fn get_mut(this: &mut Self) -> Option<&mut T> {
        let node = Arc::get_mut(&mut this.arc)?;
        Some(&mut node.elem)
    }

    /// Creates a new [`WeakRef`] to the same element.
    pub fn downgrade(this: &Self) -> WeakRef<T> {
        let weak = Arc::downgrade(&this.arc);
        WeakRef { weak }
    }
}
/// Recovers a typed [`ArcRef<T>`] from a type-erased `Arc` returned by
/// `pop_back`, `pop_lru` or the `remove_unreachable*` methods.
///
/// If the element is not a `T`, the original `Arc` is returned unchanged in
/// `Err`.
pub fn arc_dyn_any_to_arc_ref_t<T: Send + Sync + 'static>(
    aa: Arc<dyn Any + Send + Sync + 'static>,
) -> Result<ArcRef<T>, Arc<dyn Any + Send + Sync + 'static>> {
    match aa.downcast::<Node<T>>() {
        Ok(arc) => Ok(ArcRef { arc }),
        Err(original) => Err(original),
    }
}
// Unit tests for `WeakList`, `WeakRef` and `ArcRef`.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn push_elem() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let s1 = ws1.upgrade(&mut wl).expect("s1 died");
        assert_eq!(*s1, "string1");
    }
    #[test]
    fn push_pop_back() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        wl.pop_back();
        assert!(ws1.upgrade(&mut wl).is_none());
    }
    #[test]
    fn push_pop_lru() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        assert!(wl.pop_lru().is_some());
        assert!(ws1.upgrade(&mut wl).is_none());
    }
    #[test]
    fn push_pop_lru_while_upgraded() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let _s1 = ws1.upgrade(&mut wl).unwrap();
        assert!(wl.pop_lru().is_none());
    }
    #[test]
    fn push_pop_lru_moves_upgraded_tail_to_front() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let s1 = ws1.upgrade(&mut wl).unwrap();
        let _ws2 = wl.push_front(String::from("string2"));
        let _ws3 = wl.push_front(String::from("string3"));
        assert!(wl.pop_lru().is_some());
        drop(s1);
        assert!(wl.pop_lru().is_some());
        let _s1 = ws1.upgrade(&mut wl).unwrap();
    }
    #[test]
    fn push_push_upgrade_pop() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let ws2 = wl.push_front(String::from("string2"));
        assert!(ws1.upgrade(&mut wl).is_some());
        wl.pop_lru();
        assert!(ws1.upgrade(&mut wl).is_some());
        assert!(ws2.upgrade(&mut wl).is_none());
    }
    #[test]
    fn remove_unreachable() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let ws2 = wl.push_front(String::from("string2"));
        let ws3 = wl.push_front(String::from("string3"));
        drop(ws1);
        drop(ws3);
        assert_eq!(wl.remove_unreachable().len(), 2);
        assert!(ws2.upgrade(&mut wl).is_some());
        assert!(wl.pop_back().is_some());
        assert!(ws2.upgrade(&mut wl).is_none());
        assert!(wl.pop_back().is_none());
    }
    #[test]
    fn remove_unreachable_empty_buf() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        assert_eq!(wl.remove_unreachable_into_buf(&mut []), 0);
        let _s1 = ws1.upgrade(&mut wl).unwrap();
    }
    #[test]
    fn remove_unreachable_big_buf() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        drop(ws1);
        let mut buf = vec![None; 10];
        assert_eq!(wl.remove_unreachable_into_buf(&mut buf), 1);
        assert!(wl.pop_back().is_none());
    }
    #[test]
    fn heterogenous_list() {
        let mut wl = WeakList::new();
        let _ws1 = wl.push_front(String::from("string1"));
        let _ws2 = wl.push_front(8u32);
        let _ws3 = wl.push_front(vec!["a", "b", "c"]);
    }
    #[test]
    fn weak_list_is_sync_and_send() {
        fn assert_is_sync_and_send<T: Send + Sync>(_x: &T) {}
        assert_is_sync_and_send(&WeakList::new());
    }
    #[test]
    fn upgrade_node_with_another_list() {
        let mut wl1 = WeakList::new();
        let _ws1 = wl1.push_front(String::from("string1"));
        let mut wl2 = WeakList::new();
        let ws2 = wl2.push_front(String::from("string2"));
        assert!(ws2.upgrade(&mut wl1).is_none());
    }
    #[test]
    fn push_pop_back_updates_head() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        wl.pop_back();
        assert!(ws1.upgrade(&mut wl).is_none());
        let ws2 = wl.push_front(String::from("string2"));
        assert!(ws2.upgrade(&mut wl).is_some());
    }
    #[test]
    fn remove_node_twice() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let _s1 = wl.remove(&ws1).unwrap();
        assert!(wl.remove(&ws1).is_none());
    }
    #[test]
    fn remove_node_after_moving_list() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let mut wl2 = WeakList::new();
        std::mem::swap(&mut wl, &mut wl2);
        let _s1 = wl2.remove(&ws1).unwrap();
    }
    #[test]
    fn upgrade_node_after_moving_list() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let mut wl2 = WeakList::new();
        std::mem::swap(&mut wl, &mut wl2);
        let _s1 = ws1.upgrade(&mut wl2).unwrap();
    }
    #[test]
    fn remove_node_from_another_list() {
        let mut wl1 = WeakList::new();
        let ws1 = wl1.push_front(String::from("string1"));
        let mut wl2 = WeakList::new();
        let _ws2 = wl2.push_front(String::from("string2"));
        assert!(wl2.remove(&ws1).is_none());
    }
    #[test]
    fn list_cleared_on_drop() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        drop(wl);
        assert_eq!(ws1.upgrade_quietly().as_deref(), None);
    }
    #[test]
    fn upgrade_node_after_removing_from_list() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let _arc_s1 = wl.pop_back().unwrap();
        assert!(ws1.upgrade(&mut wl).is_none());
    }
    #[test]
    fn push_pop_back_recover_type_from_pop() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let arc_dyn_any = wl.pop_back().unwrap();
        drop(ws1);
        let mut as1: ArcRef<String> = arc_dyn_any_to_arc_ref_t(arc_dyn_any).unwrap();
        let s1_ref = ArcRef::get_mut(&mut as1).unwrap();
        let s1 = std::mem::take(s1_ref);
        assert_eq!(s1, "string1");
    }
    #[test]
    fn push_pop_back_recover_type_from_weak_ref() {
        let mut wl = WeakList::new();
        let ws1 = wl.push_front(String::from("string1"));
        let _arc_s1 = wl.pop_back().unwrap();
        assert!(ws1.upgrade(&mut wl).is_none());
    }
    #[test]
    fn fuzz_1() {
        let mut wl = WeakList::new();
        let mut weaks: Vec<Option<WeakRef<Vec<u8>>>> = vec![];
        let mut upgrades = vec![];
        weaks.push(Some(wl.push_front(Vec::with_capacity(91))));
        upgrades.push(weaks[0].as_ref().unwrap().upgrade(&mut wl));
        weaks.push(Some(wl.push_front(Vec::with_capacity(0))));
        wl.clear();
        upgrades.push(weaks[0].as_ref().unwrap().upgrade(&mut wl));
    }
    #[test]
    fn variance_remove() {
        let mut wl = WeakList::new();
        let ws1: WeakRef<&'static str> = wl.push_front("string1");
        fn shorten_lifetime<'a, T: ?Sized>(
            weak_ref: WeakRef<&'a T>,
            _lifetime: &'a T,
        ) -> WeakRef<&'a T> {
            weak_ref
        }
        let stack_str: &str = &String::from("hi");
        let shorter_ws1 = shorten_lifetime(ws1, stack_str);
        let s1: ArcRef<&str> = wl.remove(&shorter_ws1).unwrap();
        assert_eq!(&*s1, &"string1");
    }
    #[test]
    fn default_impl_lazy() {
        let mut wl: WeakList<WeakListHashSet> = WeakList::default();
        let ws1 = wl.push_front(String::from("string1"));
        let s1 = ws1.upgrade(&mut wl).expect("s1 died");
        assert_eq!(*s1, "string1");
    }
    #[test]
    fn linear_search_node_set() {
        let mut wl: WeakList<LinearSearch> = WeakList::default();
        let ws1 = wl.push_front(String::from("string1"));
        let ws2 = wl.push_front(String::from("string2"));
        assert_eq!(*ws1.upgrade(&mut wl).unwrap(), "string1");
        let s2 = wl.remove(&ws2).unwrap();
        assert_eq!(*s2, "string2");
        // Removed nodes are no longer found, even while still alive.
        assert!(wl.remove(&ws2).is_none());
        assert!(ws2.upgrade(&mut wl).is_none());
        // Nodes of one list are never found in another.
        let mut other: WeakList<LinearSearch> = WeakList::default();
        assert!(ws1.upgrade(&mut other).is_none());
    }
    #[test]
    fn realloc_hashset_keeps_entries() {
        let mut wl = WeakList::new();
        let mut weaks = vec![wl.push_front(0usize)];
        // Fill the node set up to its current capacity.
        while weaks.len() < wl.hashset_capacity() {
            let i = weaks.len();
            weaks.push(wl.push_front(i));
        }
        let old_cap = wl.hashset_capacity();
        let mut bigger = AllocHashSet::with_capacity(2 * old_cap + 1);
        wl.realloc_hashset_if_needed_no_alloc(Some(&mut bigger));
        assert!(wl.hashset_capacity() > old_cap);
        // The old, drained set is handed back through `bigger`.
        assert!(bigger.capacity() >= old_cap);
        for (i, w) in weaks.iter().enumerate() {
            assert_eq!(*w.upgrade(&mut wl).unwrap(), i);
        }
    }
}