either_slot/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3#![feature(allocator_api)]
4#![feature(trusted_len)]
5#![cfg_attr(loom, feature(alloc_layout_extra))]
6#![cfg_attr(test, feature(assert_matches))]
7
8#[cfg_attr(not(loom), path = "include_core.rs")]
9#[cfg_attr(loom, path = "include_loom.rs")]
10mod include;
11
12pub mod array;
13pub mod tuple;
14
15use self::include::*;
16pub use self::{
17    array::{array, vec},
18    tuple::tuple,
19};
20
21extern crate alloc;
22
23#[cfg(test)]
24extern crate std;
25
/// Untagged storage for the one value an either-slot pair can hold.
///
/// Which field (if any) is initialized is recorded separately in the `state`
/// byte of `Inner`; the union itself carries no tag and never drops anything.
union Place<A, B> {
    /// Value deposited by the `ASender`.
    a: ManuallyDrop<A>,
    /// Value deposited by the `BSender`.
    b: ManuallyDrop<B>,
    /// Placeholder used while no value has been written.
    uninit: (),
}
31
// Values of `Inner::state`. Transitions performed in this file:
//   INIT -> WRITING -> HAS_A | HAS_B   (first sender deposits its value)
//   HAS_A | HAS_B -> DONE              (second sender takes the deposited value)
//   INIT -> DONE                       (a sender is dropped before anything was sent)

/// Slot empty: neither sender has sent or been dropped yet.
const INIT: u8 = 0;
/// A sender has claimed the slot and is writing its value into `place`.
const WRITING: u8 = INIT + 1;
/// `place.a` holds the value deposited by the `ASender`.
const HAS_A: u8 = WRITING + 1;
/// `place.b` holds the value deposited by the `BSender`.
const HAS_B: u8 = HAS_A + 1;
/// Terminal: `place` holds no live value.
const DONE: u8 = HAS_B + 1;
37
38struct Inner<A, B> {
39    state: AtomicU8,
40    place: UnsafeCell<Place<A, B>>,
41}
42
43impl<A, B> Inner<A, B> {
44    const LAYOUT: Layout = Layout::new::<Self>();
45
46    fn new() -> NonNull<Self> {
47        let memory = match Global.allocate(Self::LAYOUT) {
48            Ok(memory) => memory.cast::<Self>(),
49            Err(_) => handle_alloc_error(Self::LAYOUT),
50        };
51        let value = Self {
52            state: AtomicU8::new(INIT),
53            place: UnsafeCell::new(Place { uninit: () }),
54        };
55        unsafe { memory.as_ptr().write(value) }
56        memory
57    }
58}
59
/// Reason a `send` did not leave its value in the slot.
///
/// In both variants the first field is the value the caller tried to send,
/// handed back unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SendError<P, Q> {
    /// The other sender got there first; the second field is the value it
    /// had deposited, now moved out to the caller.
    Received(P, Q),
    /// The other sender was dropped without sending anything.
    Disconnected(P),
}
65
66#[derive(Debug)]
67pub struct ASender<A, B>(NonNull<Inner<A, B>>);
68
69#[derive(Debug)]
70pub struct BSender<A, B>(NonNull<Inner<A, B>>);
71
72unsafe impl<A: Send, B: Send> Send for ASender<A, B> {}
73unsafe impl<A: Send, B: Send> Send for BSender<A, B> {}
74
75impl<A, B> ASender<A, B> {
76    const LAYOUT: Layout = Inner::<A, B>::LAYOUT;
77
78    pub fn send(self, a: A) -> Result<(), SendError<A, B>> {
79        let inner = unsafe { self.0.as_ref() };
80        loop {
81            match inner
82                .state
83                .compare_exchange(INIT, WRITING, Acquire, Acquire)
84            {
85                Ok(_) => {
86                    let a = ManuallyDrop::new(a);
87                    unsafe { inner.place.with_mut(|ptr| ptr.write(Place { a })) };
88                    inner.state.store(HAS_A, Release);
89
90                    mem::forget(self);
91                    break Ok(());
92                }
93                Err(state) => match state {
94                    WRITING => hint::spin_loop(),
95                    HAS_B => {
96                        let b = unsafe { inner.place.with_mut(|ptr| ptr.read().b) };
97                        inner.state.store(DONE, Release);
98
99                        break Err(SendError::Received(a, ManuallyDrop::into_inner(b)));
100                    }
101                    DONE => break Err(SendError::Disconnected(a)),
102                    _ => unreachable!(),
103                },
104            }
105        }
106    }
107}
108
109impl<A, B> Drop for ASender<A, B> {
110    fn drop(&mut self) {
111        let inner = unsafe { self.0.as_ref() };
112        loop {
113            let state = inner.state.load(Acquire);
114            if state != WRITING {
115                match state {
116                    INIT => {
117                        if inner
118                            .state
119                            .compare_exchange_weak(INIT, DONE, AcqRel, Acquire)
120                            .is_ok()
121                        {
122                            break;
123                        }
124                    }
125                    HAS_B => inner
126                        .place
127                        .with_mut(|ptr| unsafe { ManuallyDrop::drop(&mut (*ptr).b) }),
128                    DONE => {}
129                    _ => unreachable!(),
130                }
131                unsafe { Global.deallocate(self.0.cast(), Self::LAYOUT) };
132                break;
133            }
134            hint::spin_loop();
135        }
136    }
137}
138
139impl<A, B> BSender<A, B> {
140    const LAYOUT: Layout = Inner::<A, B>::LAYOUT;
141
142    pub fn send(self, b: B) -> Result<(), SendError<B, A>> {
143        let inner = unsafe { self.0.as_ref() };
144        loop {
145            match inner
146                .state
147                .compare_exchange(INIT, WRITING, Acquire, Acquire)
148            {
149                Ok(_) => {
150                    let b = ManuallyDrop::new(b);
151                    unsafe { inner.place.with_mut(|ptr| ptr.write(Place { b })) };
152                    inner.state.store(HAS_B, Release);
153
154                    mem::forget(self);
155                    break Ok(());
156                }
157                Err(state) => match state {
158                    WRITING => hint::spin_loop(),
159                    HAS_A => {
160                        let a = unsafe { inner.place.with_mut(|ptr| ptr.read().a) };
161                        inner.state.store(DONE, Release);
162
163                        break Err(SendError::Received(b, ManuallyDrop::into_inner(a)));
164                    }
165                    DONE => break Err(SendError::Disconnected(b)),
166                    _ => unreachable!(),
167                },
168            }
169        }
170    }
171}
172
173impl<A, B> Drop for BSender<A, B> {
174    fn drop(&mut self) {
175        let inner = unsafe { self.0.as_ref() };
176        loop {
177            let state = inner.state.load(Acquire);
178            if state != WRITING {
179                match state {
180                    INIT => {
181                        if inner
182                            .state
183                            .compare_exchange_weak(INIT, DONE, AcqRel, Acquire)
184                            .is_ok()
185                        {
186                            break;
187                        }
188                    }
189                    HAS_A => inner
190                        .place
191                        .with_mut(|ptr| unsafe { ManuallyDrop::drop(&mut (*ptr).a) }),
192                    DONE => {}
193                    _ => unreachable!(),
194                }
195                unsafe { Global.deallocate(self.0.cast(), Self::LAYOUT) };
196                break;
197            }
198            hint::spin_loop();
199        }
200    }
201}
202
203pub fn either<A, B>() -> (ASender<A, B>, BSender<A, B>) {
204    let inner = Inner::new();
205    (ASender(inner), BSender(inner))
206}
207
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    #[cfg(not(loom))]
    use std::thread;

    #[cfg(loom)]
    use loom::thread;

    use crate::{either, SendError};

    /// Runs `body` once natively, or under every interleaving loom explores
    /// when built with `--cfg loom`.
    fn run(body: fn()) {
        #[cfg(not(loom))]
        body();
        #[cfg(loom)]
        loom::model(body);
    }

    #[cfg(not(loom))]
    #[test]
    fn basic() {
        // The first sender deposits; the second gets both values back.
        let (tx_a, tx_b) = either();
        tx_a.send(1).unwrap();
        assert_eq!(tx_b.send('x'), Err(SendError::Received('x', 1)));

        // Dropping one side disconnects the other.
        let (tx_a, tx_b) = either::<_, ()>();
        drop(tx_b);
        assert_eq!(tx_a.send(1), Err(SendError::Disconnected(1)));

        // Dropping both unused halves at once.
        let _ = either::<i32, u8>();
    }

    #[test]
    fn send() {
        run(|| {
            let (tx_a, tx_b) = either();
            let handle = thread::spawn(move || tx_a.send(1));
            let from_b = tx_b.send('x');
            let from_a = handle.join().unwrap();
            assert_matches!(
                (from_b, from_a),
                (Ok(()), Err(SendError::Received(1, 'x')))
                    | (Err(SendError::Received('x', 1)), Ok(()))
            )
        });
    }

    #[test]
    fn drop_either() {
        run(|| {
            let (tx_a, tx_b) = either::<i32, _>();
            let handle = thread::spawn(move || drop(tx_a));
            assert_matches!(tx_b.send(1), Err(SendError::Disconnected(1)) | Ok(()));
            handle.join().unwrap();
        });
    }

    #[test]
    fn drop_both() {
        run(|| {
            let (tx_a, tx_b) = either::<i32, u8>();
            let handle = thread::spawn(move || drop(tx_a));
            drop(tx_b);
            handle.join().unwrap();
        });
    }
}