// fast_able/unsafe_cell_type.rs

1use std::{
2    cell::UnsafeCell,
3    ops::{Add, Deref, Div, Mul, Sub},
4};
5
/// Unsafe lock-free cell: a thin wrapper around [`UnsafeCell`] whose API is consistent with
/// `AtomicCell`.
///
/// No synchronization is performed; the caller must manually guarantee that every access is safe.
pub struct U<T> {
    // The wrapped value; all accessors go through `UnsafeCell::get`.
    _inner: UnsafeCell<T>,
}
11
12impl<T> U<T> {
13    pub const fn new(v: T) -> U<T> {
14        U {
15            _inner: UnsafeCell::new(v),
16        }
17    }
18    #[inline(always)]
19    pub fn store(&self, v: T) {
20        let s = unsafe { &mut *self._inner.get() };
21        *s = v;
22    }
23    #[inline(always)]
24    pub fn as_mut(&self) -> &mut T {
25        unsafe { &mut *self._inner.get() }
26    }
27}
28
29impl<T: Clone> Clone for U<T> {
30    fn clone(&self) -> Self {
31        Self {
32            _inner: UnsafeCell::new(self.deref().clone()),
33        }
34    }
35}
36
// Marker impl: `U<T>` is `Eq` whenever `T` is; the comparison itself lives in `PartialEq`.
impl<T: Eq> Eq for U<T> {}
38
39impl<T: PartialEq> PartialEq for U<T> {
40    fn eq(&self, other: &Self) -> bool {
41        self.deref() == other.deref()
42    }
43}
44
45impl<T: Default> Default for U<T> {
46    fn default() -> Self {
47        Self {
48            _inner: Default::default(),
49        }
50    }
51}
52
53impl<T: Clone> U<T> {
54    pub fn load(&self) -> T {
55        self.deref().clone()
56    }
57    pub fn fetch_end(&self, v: T) -> T {
58        let r = self.deref().clone();
59        self.store(v);
60        r
61    }
62}
63
64impl<T: Add<Output = T> + Clone> U<T> {
65    pub fn fetch_add(&self, v: T) -> T {
66        let r = self.deref().clone();
67        let s = unsafe { &mut *self._inner.get() };
68        *s = self.deref().clone() + v;
69        r
70    }
71}
72
73impl<T: Sub<Output = T> + Clone> U<T> {
74    pub fn fetch_sub(&self, v: T) -> T {
75        let r = self.deref().clone();
76        let s = unsafe { &mut *self._inner.get() };
77        *s = self.deref().clone() - v;
78        r
79    }
80}
81
82use core::fmt::Debug;
83impl<T: Debug> Debug for U<T> {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.write_fmt(format_args!("{:?}", self.deref()))
86    }
87}
88
89use core::fmt::Display;
90impl<T: Display> Display for U<T> {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.write_fmt(format_args!("{}", self.deref()))
93    }
94}
95
96impl<T: Add<Output = T> + Clone> Add for U<T> {
97    fn add(self, rhs: Self) -> Self::Output {
98        let v1 = unsafe { &*self._inner.get() };
99        let v2 = unsafe { &*rhs._inner.get() };
100        Self::Output::new(v1.clone() + v2.clone())
101    }
102
103    type Output = U<T>;
104}
105
106impl<T: Sub<Output = T> + Clone> Sub for U<T> {
107    fn sub(self, rhs: Self) -> Self::Output {
108        let v1 = unsafe { &*self._inner.get() };
109        let v2 = unsafe { &*rhs._inner.get() };
110        Self::Output::new(v1.clone() - v2.clone())
111    }
112
113    type Output = U<T>;
114}
115
116impl<T: Div<Output = T> + Clone> Div for U<T> {
117    fn div(self, rhs: Self) -> Self::Output {
118        let v1 = unsafe { &*self._inner.get() };
119        let v2 = unsafe { &*rhs._inner.get() };
120        Self::Output::new(v1.clone() / v2.clone())
121    }
122
123    type Output = U<T>;
124}
125
126impl<T: Mul<Output = T> + Clone> Mul for U<T> {
127    fn mul(self, rhs: Self) -> Self::Output {
128        let v1 = unsafe { &*self._inner.get() };
129        let v2 = unsafe { &*rhs._inner.get() };
130        Self::Output::new(v1.clone() * v2.clone())
131    }
132
133    type Output = U<T>;
134}
135
136// impl Serialize and Deserialize
137use serde::{Deserialize, Serialize};
138
139impl<T: Serialize> Serialize for U<T> {
140    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
141    where
142        S: serde::Serializer,
143    {
144        // 序列化内部的值
145        self.deref().serialize(serializer)
146    }
147}
148
149impl<'de, T: Deserialize<'de>> Deserialize<'de> for U<T> {
150    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
151    where
152        D: serde::Deserializer<'de>,
153    {
154        // 反序列化得到内部的值,然后包装成 U<T>
155        let value = T::deserialize(deserializer)?;
156        Ok(U::new(value))
157    }
158}
159
// Thread-safety opt-ins. `UnsafeCell` is never `Sync` on its own, so sharing a `U<T>` between
// threads relies on these manual impls. No synchronization backs them: `store`, `as_mut` and the
// `fetch_*` methods write through `&self`, so callers must prevent data races themselves.
unsafe impl<T: Send> Send for U<T> {}
// NOTE(review): values can be moved in and out through a shared reference (`store`, `fetch_end`),
// which usually calls for `T: Send` on the `Sync` impl as well (compare `std::sync::Mutex`) —
// confirm whether the bound should be `T: Send + Sync`.
unsafe impl<T: Sync> Sync for U<T> {}
162
163impl<T> Deref for U<T> {
164    type Target = T;
165    fn deref(&self) -> &Self::Target {
166        unsafe { &*self._inner.get() }
167    }
168}
169
170impl<T> AsRef<T> for U<T> {
171    fn as_ref(&self) -> &T {
172        self.deref()
173    }
174}
175
176impl<T> From<T> for U<T> {
177    fn from(value: T) -> Self {
178        U {
179            _inner: UnsafeCell::new(value),
180        }
181    }
182}
183
/// Single-threaded smoke test of the operator impls and the load/store/fetch API.
#[test]
fn test() {
    let cell = U::new(1);

    // `+` consumes both operands and yields a fresh cell.
    let one = 1.into();
    let sum = cell.clone() + one;
    println!("r: {}", sum);
    assert_eq!(sum.load(), 2);

    // `-` behaves the same way.
    let one = 1.into();
    let diff = cell.clone() - one;
    println!("r: {}", diff);
    assert_eq!(diff.load(), 0);

    // `fetch_add` returns the old value and leaves the sum in the cell.
    let before_add = cell.fetch_add(3);
    println!("r: {}", before_add);
    assert_eq!(before_add, 1);
    assert_eq!(cell.load(), 4);

    // `fetch_end` swaps in the new value and returns the old one.
    let before_swap = cell.fetch_end(5);
    println!("r: {}", before_swap);
    assert_eq!(before_swap, 4);

    let current = cell.load();
    println!("r: {}", current);
    assert_eq!(current, 5);

    // Plain overwrite.
    cell.store(6);
    println!("r: {}", cell.load());
    assert_eq!(cell.load(), 6);

    // `fetch_sub` result is irrelevant here; only the stored value is checked.
    cell.fetch_sub(5);
    println!("r: {}", cell.load());
    assert_eq!(cell.load(), 1);
}
218
/// Stress test: six threads spin on `load` while the main thread repeatedly calls `fetch_add`.
///
/// Ignored by default because the writer loop sleeps 100 ms for each of 1_000_000 iterations
/// and the reader threads never terminate; run it explicitly with `cargo test -- --ignored`.
///
/// NOTE(review): readers and the writer access `V` without any synchronization, which is
/// exactly what `U` leaves to the caller — confirm this is the intended demonstration.
#[test]
#[ignore = "long-running stress test (1_000_000 iterations x 100 ms); run with --ignored"]
fn test_mut_thread() {
    // NOTE(review): mutating the environment is only sound while no other thread touches it;
    // other tests in the same binary may run in parallel — confirm.
    unsafe { std::env::set_var("RUST_LOG", "debug") };
    // `try_init` instead of `init`: `init` panics when a global logger has already been
    // installed (e.g. by another test in this binary); here that case is simply ignored.
    let _ = env_logger::try_init();

    static V: U<usize> = U::new(0);

    // Reader threads: spin forever, cloning the current value.
    const READERS: usize = 6;
    for _ in 0..READERS {
        std::thread::spawn(|| loop {
            V.load();
        });
    }

    // Writer: bump the shared value once every 100 ms and log the previous value.
    for i in 0..1000000 {
        std::thread::sleep(std::time::Duration::from_millis(100));
        let r = V.fetch_add(i);
        debug!("loop {}: {r}", i);
    }
}
250
/// Checks that `U<T>` round-trips through JSON as if it were a bare `T`.
#[test]
fn test_serialize() {
    // Scalar: serialize to JSON ...
    let original = U::new(42);
    let json = serde_json::to_string(&original).expect("Failed to serialize");
    println!("Serialized: {}", json);

    // ... deserialize it again ...
    let restored: U<i32> = serde_json::from_str(&json).expect("Failed to deserialize");
    println!("Deserialized: {}", restored.load());

    // ... and make sure the value survived.
    assert_eq!(original.load(), restored.load());

    // Same round trip with a compound value.
    let original_vec = U::new(vec![1, 2, 3, 4, 5]);
    let json = serde_json::to_string(&original_vec).expect("Failed to serialize vec");
    println!("Serialized vec: {}", json);

    let restored_vec: U<Vec<i32>> =
        serde_json::from_str(&json).expect("Failed to deserialize vec");
    println!("Deserialized vec: {:?}", restored_vec.load());

    assert_eq!(original_vec.load(), restored_vec.load());
}