Skip to main content

fast_able/
unsafe_cell_type.rs

1use std::{
2    cell::UnsafeCell,
3    ops::{Add, Deref, Div, Mul, Sub},
4};
5
6/// Unsafe lock-free type, please manually ensure access safety, API is consistent with AtomicCell
7/// unsafe 无锁类型, 请手动保证访问安全, api 与 AtomicCell 一致
8pub struct U<T> {
9    _inner: UnsafeCell<T>,
10}
11
12impl<T> U<T> {
13    pub const fn new(v: T) -> U<T> {
14        U {
15            _inner: UnsafeCell::new(v),
16        }
17    }
18    #[inline(always)]
19    pub fn store(&self, v: T) {
20        let s = unsafe { &mut *self._inner.get() };
21        *s = v;
22    }
23    #[inline(always)]
24    pub fn as_mut(&self) -> &mut T {
25        unsafe { &mut *self._inner.get() }
26    }
27
28    /// 消费 U<T>,返回内部值
29    #[inline(always)]
30    pub fn into_inner(self) -> T {
31        self._inner.into_inner()
32    }
33}
34
35impl<T: Clone> Clone for U<T> {
36    fn clone(&self) -> Self {
37        Self {
38            _inner: UnsafeCell::new(self.deref().clone()),
39        }
40    }
41}
42
43impl<T: Eq> Eq for U<T> {}
44
45impl<T: PartialEq> PartialEq for U<T> {
46    fn eq(&self, other: &Self) -> bool {
47        self.deref() == other.deref()
48    }
49}
50
51impl<T: Default> Default for U<T> {
52    fn default() -> Self {
53        Self {
54            _inner: Default::default(),
55        }
56    }
57}
58
59impl<T: Clone> U<T> {
60    pub fn load(&self) -> T {
61        self.deref().clone()
62    }
63    pub fn fetch_end(&self, v: T) -> T {
64        let r = self.deref().clone();
65        self.store(v);
66        r
67    }
68}
69
70impl<T: Add<Output = T> + Clone> U<T> {
71    pub fn fetch_add(&self, v: T) -> T {
72        let r = self.deref().clone();
73        let s = unsafe { &mut *self._inner.get() };
74        *s = self.deref().clone() + v;
75        r
76    }
77}
78
79impl<T: Sub<Output = T> + Clone> U<T> {
80    pub fn fetch_sub(&self, v: T) -> T {
81        let r = self.deref().clone();
82        let s = unsafe { &mut *self._inner.get() };
83        *s = self.deref().clone() - v;
84        r
85    }
86}
87
88use core::fmt::Debug;
89impl<T: Debug> Debug for U<T> {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.write_fmt(format_args!("{:?}", self.deref()))
92    }
93}
94
95use core::fmt::Display;
96impl<T: Display> Display for U<T> {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.write_fmt(format_args!("{}", self.deref()))
99    }
100}
101
102impl<T: Add<Output = T> + Clone> Add for U<T> {
103    fn add(self, rhs: Self) -> Self::Output {
104        let v1 = unsafe { &*self._inner.get() };
105        let v2 = unsafe { &*rhs._inner.get() };
106        Self::Output::new(v1.clone() + v2.clone())
107    }
108
109    type Output = U<T>;
110}
111
112impl<T: Sub<Output = T> + Clone> Sub for U<T> {
113    fn sub(self, rhs: Self) -> Self::Output {
114        let v1 = unsafe { &*self._inner.get() };
115        let v2 = unsafe { &*rhs._inner.get() };
116        Self::Output::new(v1.clone() - v2.clone())
117    }
118
119    type Output = U<T>;
120}
121
122impl<T: Div<Output = T> + Clone> Div for U<T> {
123    fn div(self, rhs: Self) -> Self::Output {
124        let v1 = unsafe { &*self._inner.get() };
125        let v2 = unsafe { &*rhs._inner.get() };
126        Self::Output::new(v1.clone() / v2.clone())
127    }
128
129    type Output = U<T>;
130}
131
132impl<T: Mul<Output = T> + Clone> Mul for U<T> {
133    fn mul(self, rhs: Self) -> Self::Output {
134        let v1 = unsafe { &*self._inner.get() };
135        let v2 = unsafe { &*rhs._inner.get() };
136        Self::Output::new(v1.clone() * v2.clone())
137    }
138
139    type Output = U<T>;
140}
141
142// impl Serialize and Deserialize
143use serde::{Deserialize, Serialize};
144
145impl<T: Serialize> Serialize for U<T> {
146    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
147    where
148        S: serde::Serializer,
149    {
150        // 序列化内部的值
151        self.deref().serialize(serializer)
152    }
153}
154
155impl<'de, T: Deserialize<'de>> Deserialize<'de> for U<T> {
156    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
157    where
158        D: serde::Deserializer<'de>,
159    {
160        // 反序列化得到内部的值,然后包装成 U<T>
161        let value = T::deserialize(deserializer)?;
162        Ok(U::new(value))
163    }
164}
165
166unsafe impl<T: Send> Send for U<T> {}
167unsafe impl<T: Sync> Sync for U<T> {}
168
169impl<T> Deref for U<T> {
170    type Target = T;
171    fn deref(&self) -> &Self::Target {
172        unsafe { &*self._inner.get() }
173    }
174}
175
176impl<T> AsRef<T> for U<T> {
177    fn as_ref(&self) -> &T {
178        self.deref()
179    }
180}
181
182impl<T> From<T> for U<T> {
183    fn from(value: T) -> Self {
184        U {
185            _inner: UnsafeCell::new(value),
186        }
187    }
188}
189
190#[test]
191fn test() {
192    let v1 = U::new(1);
193    let v2 = 1.into();
194    let v3 = v1.clone() + v2;
195    println!("r: {}", v3);
196    assert_eq!(v3.load(), 2);
197
198    let v2 = 1.into();
199    let v3 = v1.clone() - v2;
200    println!("r: {}", v3);
201    assert_eq!(v3.load(), 0);
202
203    let v3 = v1.fetch_add(3);
204    println!("r: {}", v3);
205    assert_eq!(v3, 1);
206    assert_eq!(v1.load(), 4);
207
208    let v3 = v1.fetch_end(5);
209    println!("r: {}", v3);
210    assert_eq!(v3, 4);
211
212    let v3 = v1.load();
213    println!("r: {}", v3);
214    assert_eq!(v3, 5);
215
216    v1.store(6);
217    println!("r: {}", v1.load());
218    assert_eq!(v1.load(), 6);
219
220    v1.fetch_sub(5);
221    println!("r: {}", v1.load());
222    assert_eq!(v1.load(), 1);
223}
224
225#[test]
226fn test_mut_thread() {
227    unsafe { std::env::set_var("RUST_LOG", "debug") };
228    env_logger::init();
229
230    static V: U<usize> = U::new(0);
231    std::thread::spawn(move || loop {
232        V.load();
233    });
234    std::thread::spawn(move || loop {
235        V.load();
236    });
237    std::thread::spawn(move || loop {
238        V.load();
239    });
240    std::thread::spawn(move || loop {
241        V.load();
242    });
243    std::thread::spawn(move || loop {
244        V.load();
245    });
246    std::thread::spawn(move || loop {
247        V.load();
248    });
249
250    for i in 0..1000000 {
251        std::thread::sleep(std::time::Duration::from_millis(100));
252        let r = V.fetch_add(i);
253        debug!("loop {}: {r}", i);
254    }
255}
256
257#[test]
258fn test_serialize() {
259    // 测试序列化和反序列化
260    let v1 = U::new(42);
261    
262    // 序列化为 JSON
263    let json = serde_json::to_string(&v1).expect("Failed to serialize");
264    println!("Serialized: {}", json);
265    
266    // 反序列化
267    let v2: U<i32> = serde_json::from_str(&json).expect("Failed to deserialize");
268    println!("Deserialized: {}", v2.load());
269    
270    // 验证值是否相等
271    assert_eq!(v1.load(), v2.load());
272    
273    // 测试复杂类型
274    let v3 = U::new(vec![1, 2, 3, 4, 5]);
275    let json = serde_json::to_string(&v3).expect("Failed to serialize vec");
276    println!("Serialized vec: {}", json);
277    
278    let v4: U<Vec<i32>> = serde_json::from_str(&json).expect("Failed to deserialize vec");
279    println!("Deserialized vec: {:?}", v4.load());
280    
281    assert_eq!(v3.load(), v4.load());
282}