fast_able/
unsafe_cell_type.rs

1use std::{
2    cell::UnsafeCell,
3    ops::{Add, Deref, Div, Mul, Sub},
4};
5
6/// unsafe 无锁类型, 请手动保证访问安全, api 与 AtomicCell 一致
7pub struct U<T> {
8    _inner: UnsafeCell<T>,
9}
10
11impl<T> U<T> {
12    pub const fn new(v: T) -> U<T> {
13        U {
14            _inner: UnsafeCell::new(v),
15        }
16    }
17    #[inline(always)]
18    pub fn store(&self, v: T) {
19        let s = unsafe { &mut *self._inner.get() };
20        *s = v;
21    }
22    #[inline(always)]
23    pub fn as_mut(&self) -> &mut T {
24        unsafe { &mut *self._inner.get() }
25    }
26}
27
28impl<T: Clone> Clone for U<T> {
29    fn clone(&self) -> Self {
30        Self {
31            _inner: UnsafeCell::new(self.deref().clone()),
32        }
33    }
34}
35
36impl<T: Eq> Eq for U<T> {}
37
38impl<T: PartialEq> PartialEq for U<T> {
39    fn eq(&self, other: &Self) -> bool {
40        self.deref() == other.deref()
41    }
42}
43
44impl<T: Default> Default for U<T> {
45    fn default() -> Self {
46        Self {
47            _inner: Default::default(),
48        }
49    }
50}
51
52impl<T: Clone> U<T> {
53    pub fn load(&self) -> T {
54        self.deref().clone()
55    }
56    pub fn fetch_end(&self, v: T) -> T {
57        let r = self.deref().clone();
58        self.store(v);
59        r
60    }
61}
62
63impl<T: Add<Output = T> + Clone> U<T> {
64    pub fn fetch_add(&self, v: T) -> T {
65        let r = self.deref().clone();
66        let s = unsafe { &mut *self._inner.get() };
67        *s = self.deref().clone() + v;
68        r
69    }
70}
71
72impl<T: Sub<Output = T> + Clone> U<T> {
73    pub fn fetch_sub(&self, v: T) -> T {
74        let r = self.deref().clone();
75        let s = unsafe { &mut *self._inner.get() };
76        *s = self.deref().clone() - v;
77        r
78    }
79}
80
81use core::fmt::Debug;
82impl<T: Debug> Debug for U<T> {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.write_fmt(format_args!("{:?}", self.deref()))
85    }
86}
87
88use core::fmt::Display;
89impl<T: Display> Display for U<T> {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.write_fmt(format_args!("{}", self.deref()))
92    }
93}
94
95impl<T: Add<Output = T> + Clone> Add for U<T> {
96    fn add(self, rhs: Self) -> Self::Output {
97        let v1 = unsafe { &*self._inner.get() };
98        let v2 = unsafe { &*rhs._inner.get() };
99        Self::Output::new(v1.clone() + v2.clone())
100    }
101
102    type Output = U<T>;
103}
104
105impl<T: Sub<Output = T> + Clone> Sub for U<T> {
106    fn sub(self, rhs: Self) -> Self::Output {
107        let v1 = unsafe { &*self._inner.get() };
108        let v2 = unsafe { &*rhs._inner.get() };
109        Self::Output::new(v1.clone() - v2.clone())
110    }
111
112    type Output = U<T>;
113}
114
115impl<T: Div<Output = T> + Clone> Div for U<T> {
116    fn div(self, rhs: Self) -> Self::Output {
117        let v1 = unsafe { &*self._inner.get() };
118        let v2 = unsafe { &*rhs._inner.get() };
119        Self::Output::new(v1.clone() / v2.clone())
120    }
121
122    type Output = U<T>;
123}
124
125impl<T: Mul<Output = T> + Clone> Mul for U<T> {
126    fn mul(self, rhs: Self) -> Self::Output {
127        let v1 = unsafe { &*self._inner.get() };
128        let v2 = unsafe { &*rhs._inner.get() };
129        Self::Output::new(v1.clone() * v2.clone())
130    }
131
132    type Output = U<T>;
133}
134
135// impl Serialize and Deserialize
136use serde::{Deserialize, Serialize};
137
138impl<T: Serialize> Serialize for U<T> {
139    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
140    where
141        S: serde::Serializer,
142    {
143        // 序列化内部的值
144        self.deref().serialize(serializer)
145    }
146}
147
148impl<'de, T: Deserialize<'de>> Deserialize<'de> for U<T> {
149    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
150    where
151        D: serde::Deserializer<'de>,
152    {
153        // 反序列化得到内部的值,然后包装成 U<T>
154        let value = T::deserialize(deserializer)?;
155        Ok(U::new(value))
156    }
157}
158
159unsafe impl<T: Send> Send for U<T> {}
160unsafe impl<T: Sync> Sync for U<T> {}
161
162impl<T> Deref for U<T> {
163    type Target = T;
164    fn deref(&self) -> &Self::Target {
165        unsafe { &*self._inner.get() }
166    }
167}
168
169impl<T> AsRef<T> for U<T> {
170    fn as_ref(&self) -> &T {
171        self.deref()
172    }
173}
174
175impl<T> From<T> for U<T> {
176    fn from(value: T) -> Self {
177        U {
178            _inner: UnsafeCell::new(value),
179        }
180    }
181}
182
/// Single-threaded walk through the operators and the load/store/fetch API.
#[test]
fn test() {
    let base = U::new(1);

    // `+` on two cells yields a new cell holding the sum.
    let addend = 1.into();
    let sum = base.clone() + addend;
    println!("r: {}", sum);
    assert_eq!(sum.load(), 2);

    // `-` on two cells yields a new cell holding the difference.
    let subtrahend = 1.into();
    let difference = base.clone() - subtrahend;
    println!("r: {}", difference);
    assert_eq!(difference.load(), 0);

    // `fetch_add` returns the old value and leaves the sum in the cell.
    let before_add = base.fetch_add(3);
    println!("r: {}", before_add);
    assert_eq!(before_add, 1);
    assert_eq!(base.load(), 4);

    // `fetch_end` swaps in a new value and returns the old one.
    let before_swap = base.fetch_end(5);
    println!("r: {}", before_swap);
    assert_eq!(before_swap, 4);

    let current = base.load();
    println!("r: {}", current);
    assert_eq!(current, 5);

    // `store` overwrites the value.
    base.store(6);
    println!("r: {}", base.load());
    assert_eq!(base.load(), 6);

    // `fetch_sub` subtracts in place; its return value is not needed here.
    base.fetch_sub(5);
    println!("r: {}", base.load());
    assert_eq!(base.load(), 1);
}
217
/// Manual stress test: six threads read a shared static `U<usize>` in a tight
/// loop while this thread keeps adding to it.
///
/// Ignored by default because the loop below sleeps 100 ms for each of
/// 1_000_000 iterations and the reader threads never terminate. Run it
/// explicitly with `cargo test -- --ignored test_mut_thread`.
///
/// NOTE(review): readers and writer touch `V` with no synchronization at all;
/// confirm this is an acceptable thing to exercise in a test.
#[test]
#[ignore = "long-running manual stress test (1_000_000 iterations x 100 ms, reader threads never exit)"]
fn test_mut_thread() {
    // NOTE(review): the test harness may run other tests on parallel threads;
    // confirm nothing reads the environment while this variable is being set.
    unsafe { std::env::set_var("RUST_LOG", "debug") };
    env_logger::init();

    static V: U<usize> = U::new(0);

    // Spawn the six identical reader threads; each one loads forever.
    for _ in 0..6 {
        std::thread::spawn(move || loop {
            V.load();
        });
    }

    for i in 0..1000000 {
        std::thread::sleep(std::time::Duration::from_millis(100));
        let r = V.fetch_add(i);
        debug!("loop {}: {r}", i);
    }
}
249
/// Round-trips `U<T>` through JSON for a scalar and for a `Vec`.
#[test]
fn test_serialize() {
    // Scalar: serialize to JSON ...
    let scalar = U::new(42);
    let scalar_json = serde_json::to_string(&scalar).expect("Failed to serialize");
    println!("Serialized: {}", scalar_json);

    // ... deserialize it again ...
    let scalar_back: U<i32> = serde_json::from_str(&scalar_json).expect("Failed to deserialize");
    println!("Deserialized: {}", scalar_back.load());

    // ... and check that the value survived.
    assert_eq!(scalar.load(), scalar_back.load());

    // Same round trip with a more complex payload.
    let list = U::new(vec![1, 2, 3, 4, 5]);
    let list_json = serde_json::to_string(&list).expect("Failed to serialize vec");
    println!("Serialized vec: {}", list_json);

    let list_back: U<Vec<i32>> =
        serde_json::from_str(&list_json).expect("Failed to deserialize vec");
    println!("Deserialized vec: {:?}", list_back.load());

    assert_eq!(list.load(), list_back.load());
}