1use std::{
2 fmt::Display,
3 ops::Deref,
4 sync::{atomic::Ordering, Arc},
5};
6
7use crate::{ref_inner::RefInner, ref_mut::RefMut};
8
9pub struct Ref<T> {
10 pub(crate) inner: Arc<RefInner<T>>,
11}
12
13unsafe impl<T> Sync for Ref<T> {}
14unsafe impl<T> Send for Ref<T> {}
15
16impl<T> Ref<T> {
17 pub fn new(data: T) -> Self {
18 Self {
19 inner: Arc::new(RefInner::new(data)),
20 }
21 }
22
23 pub fn locked(&self) -> bool {
24 self.inner.lock.load(Ordering::Relaxed) > 0
25 }
26
27 pub fn lock(&self) {
29 while self
30 .inner
31 .lock
32 .compare_exchange_weak(0, 1, Ordering::Acquire, Ordering::Relaxed)
33 .is_err()
34 {
35 atomic_wait::wait(&self.inner.lock, 1);
36 }
37 }
38
39 pub fn unlock(&self) {
41 self.inner.lock.store(0, Ordering::Release);
42 atomic_wait::wake_one(&self.inner.lock)
43 }
44
45 pub fn mut_scope(&self, clasure: impl Fn(&mut T)) {
49 self.lock();
50 clasure(unsafe { &mut *self.inner.cell.get() });
51 self.unlock()
52 }
53
54 pub fn get_mut(&self) -> RefMut<T> {
57 self.lock();
58
59 RefMut {
60 inner: self.inner.clone(),
61 }
62 }
63}
64
65impl<T> Clone for Ref<T> {
66 fn clone(&self) -> Self {
67 Self {
68 inner: self.inner.clone(),
69 }
70 }
71}
72
73impl<T> Deref for Ref<T> {
74 type Target = T;
75
76 fn deref(&self) -> &Self::Target {
77 unsafe { &*self.inner.cell.get() }
78 }
79}
80
81impl<T> Display for Ref<T>
82where
83 T: Display,
84{
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 self.deref().fmt(f)
87 }
88}
89
90impl<T> std::fmt::Debug for Ref<T>
91where
92 T: std::fmt::Debug,
93{
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 self.deref().fmt(f)
96 }
97}
98
#[cfg(test)]
mod test {
    use super::Ref;
    use std::thread;

    /// Number of worker threads hammering the shared counter.
    const THREADS: i32 = 3;
    /// Increments performed by each worker thread.
    const ITERATIONS: i32 = 5_000_000;

    /// Several threads increment one shared counter through `get_mut`. If the
    /// lock provides mutual exclusion, no increment is lost and the final
    /// value is exactly `THREADS * ITERATIONS`.
    #[test]
    fn threading_1() {
        let data = Ref::new(0);

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let handle = data.clone();
                thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        let mut guard = handle.get_mut();
                        *guard += 1;
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // All writers have been joined, so the unlocked `Deref` read is safe here.
        assert_eq!(*data, THREADS * ITERATIONS)
    }
}
141}