rust2go/
slot.rs

1// Copyright 2024 ihciah. All Rights Reserved.
2
3use std::{
4    mem::MaybeUninit,
5    ptr::NonNull,
6    sync::{
7        atomic::{
8            AtomicU8,
9            Ordering::{AcqRel, Acquire},
10        },
11        Mutex,
12    },
13    task::Waker,
14};
15
16/// Create a pair of SlotReader and SlotWriter.
17/// There's 2 reasons to use it when async rust to go ffi(Go holds writer and rust holds reader):
18/// 1. Rust cannot guarantee trying read before go write.
19/// 2. Rust can dealloc the memory before go write by simply drop it if using a Box directly.
20#[inline]
21pub fn new_atomic_slot<T, A>() -> (SlotReader<T, A>, SlotWriter<T, A>) {
22    let inner = SlotInner::new();
23    let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(inner))) };
24    (SlotReader(ptr), SlotWriter(ptr))
25}
26
27struct SlotInner<T, A = ()> {
28    state: State,
29    data: MaybeUninit<T>,
30    attachment: Option<A>,
31    waker: Mutex<Option<Waker>>,
32}
33
34impl<T, A> Drop for SlotInner<T, A> {
35    fn drop(&mut self) {
36        if self.state.load() & 0b100 != 0 {
37            unsafe { self.data.assume_init_drop() };
38        }
39    }
40}
41
/// Lifecycle flags of a slot, packed into one atomic byte.
///
/// - `0b001`: set once the writer has been dropped (`0` while it is alive).
/// - `0b010`: set once the reader has been dropped (`0` while it is alive).
/// - `0b100`: set once a value has been written (cleared again when `read` takes it).
#[repr(transparent)]
struct State(AtomicU8);

impl State {
    /// Loads the current flags with `Acquire` ordering.
    #[inline]
    fn load(&self) -> u8 {
        self.0.load(Acquire)
    }

    /// Atomically applies `f` to the current flags and returns its output.
    ///
    /// `f` receives the current value and returns `(output, next)`:
    /// - `next == None`: the state is left untouched and `output` is returned.
    /// - `next == Some(v)`: `v` is stored if the state is still the value `f` saw;
    ///   otherwise `f` is called again with the fresh value.
    ///
    /// `f` may therefore run more than once. Only the output of the final call (the
    /// one whose update was committed, or which returned `None`) is returned.
    fn fetch_update_action<F, O>(&self, mut f: F) -> O
    where
        F: FnMut(u8) -> (O, Option<u8>),
    {
        let mut curr = self.0.load(Acquire);
        loop {
            let (output, next) = f(curr);
            let next = match next {
                Some(next) => next,
                None => return output,
            };

            // The weak variant may fail spuriously. That is harmless because we retry
            // anyway, and it avoids a nested retry loop on LL/SC architectures.
            match self.0.compare_exchange_weak(curr, next, AcqRel, Acquire) {
                Ok(_) => return output,
                Err(actual) => curr = actual,
            }
        }
    }
}
75
76impl<T, A> SlotInner<T, A> {
77    #[inline]
78    const fn new() -> Self {
79        Self {
80            state: State(AtomicU8::new(0)),
81            data: MaybeUninit::uninit(),
82            attachment: None,
83            waker: Mutex::new(None),
84        }
85    }
86
87    #[inline]
88    fn read(&self) -> Option<T> {
89        let mut data = MaybeUninit::uninit();
90        let copied = self.state.fetch_update_action(|curr| {
91            if curr & 0b101 == 0b101 {
92                // data has been written and writer has been dropped(data has been fully written)
93                unsafe { data = MaybeUninit::new(self.data.as_ptr().read()) };
94                // unset the written bit
95                (true, Some(curr & 0b011))
96            } else {
97                (false, None)
98            }
99        });
100
101        if copied {
102            Some(unsafe { data.assume_init() })
103        } else {
104            None
105        }
106    }
107
108    #[inline]
109    fn write(&mut self, data: T) -> Option<T> {
110        let succ = self.state.fetch_update_action(|curr| {
111            if curr & 0b100 != 0 {
112                // data has been written or another writer has got this bit(but this would not happen in fact)
113                (false, None)
114            } else {
115                // we got this bit
116                (true, Some(0b100 | curr))
117            }
118        });
119
120        if !succ {
121            return Some(data);
122        }
123
124        unsafe { self.data.as_mut_ptr().write(data) };
125        None
126    }
127}
128
129#[repr(transparent)]
130pub struct SlotReader<T, A = ()>(NonNull<SlotInner<T, A>>);
131unsafe impl<T: Send, A: Send> Send for SlotReader<T, A> {}
132unsafe impl<T: Send, A: Send> Sync for SlotReader<T, A> {}
133
134impl<T, A> SlotReader<T, A> {
135    #[inline]
136    pub fn read(&self) -> Option<T> {
137        unsafe { self.0.as_ref() }.read()
138    }
139
140    /// # Safety
141    /// Must be read after attachment write.
142    #[inline]
143    pub unsafe fn read_with_attachment(&mut self) -> Option<(T, Option<A>)> {
144        let inner = unsafe { self.0.as_mut() };
145        inner.read().map(|res| (res, inner.attachment.take()))
146    }
147
148    #[inline]
149    pub(crate) fn set_waker(&mut self, waker: &Waker) {
150        unsafe {
151            let mut waker_locked = self.0.as_mut().waker.lock().unwrap();
152            match waker_locked.as_mut() {
153                None => *waker_locked = Some(waker.clone()),
154                Some(w) => w.clone_from(waker),
155            }
156        }
157    }
158}
159
160impl<T, A> Drop for SlotReader<T, A> {
161    #[inline]
162    fn drop(&mut self) {
163        unsafe {
164            if self
165                .0
166                .as_ref()
167                .state
168                .fetch_update_action(|curr| (curr & 0b001 != 0, Some(0b010 | curr)))
169            {
170                drop(Box::from_raw(self.0.as_ptr()));
171            }
172        }
173    }
174}
175
176#[repr(transparent)]
177pub struct SlotWriter<T, A = ()>(NonNull<SlotInner<T, A>>);
178unsafe impl<T: Send, A: Send> Send for SlotWriter<T, A> {}
179unsafe impl<T: Send, A: Send> Sync for SlotWriter<T, A> {}
180
181impl<T, A> SlotWriter<T, A> {
182    #[inline]
183    pub fn write(mut self, data: T) {
184        if unsafe { self.0.as_mut() }.write(data).is_none() {
185            let waker = unsafe { self.0.as_ref().waker.lock().unwrap().take() };
186            if let Some(waker) = waker {
187                drop(self);
188                waker.wake();
189            }
190        }
191    }
192
193    #[inline]
194    pub fn into_ptr(self) -> *const () {
195        let ptr = self.0.as_ptr() as *const ();
196        std::mem::forget(self);
197        ptr
198    }
199
200    /// # Safety
201    /// Pointer must be a valid *SlotInner<T>.
202    #[inline]
203    pub unsafe fn from_ptr(ptr: *const ()) -> Self {
204        Self(NonNull::new_unchecked(ptr as _))
205    }
206
207    #[inline]
208    pub(crate) fn attach(&mut self, attachment: A) -> &mut A {
209        unsafe { self.0.as_mut() }.attachment.insert(attachment)
210    }
211
212    #[inline]
213    pub(crate) fn set_waker(&mut self, waker: Waker) {
214        unsafe { *self.0.as_mut().waker.lock().unwrap() = Some(waker) };
215    }
216}
217
218impl<T, A> Drop for SlotWriter<T, A> {
219    #[inline]
220    fn drop(&mut self) {
221        unsafe {
222            if self
223                .0
224                .as_ref()
225                .state
226                .fetch_update_action(|curr| (curr & 0b010 != 0, Some(0b001 | curr)))
227            {
228                drop(Box::from_raw(self.0.as_ptr()));
229            }
230        }
231    }
232}