Skip to main content

compio_send_wrapper/
lib.rs

1// Copyright 2017 Thomas Keh.
2// Copyright 2024 compio-rs
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
//! This [Rust] library implements a wrapper type called [`SendWrapper`] which
//! allows you to move around non-[`Send`] types between threads, as long as you
//! access the contained value only from within the original thread. You also
//! have to make sure that the wrapper is dropped from within the original
//! thread. Violations are detected at runtime: the checked accessors (such as
//! [`SendWrapper::get`]) return `None`, while taking or cloning the value, or
//! dropping a value with a non-trivial destructor, from another thread panics.
//! [`SendWrapper<T>`] implements [`Send`] and [`Sync`] for any type `T`.
16//!
17//! # Examples
18//!
19//! ```rust
20//! use std::{rc::Rc, sync::mpsc::channel, thread};
21//!
22//! use compio_send_wrapper::SendWrapper;
23//!
24//! // Rc is a non-Send type.
25//! let value = Rc::new(42);
26//!
27//! // We now wrap the value with `SendWrapper` (value is moved inside).
28//! let wrapped_value = SendWrapper::new(value);
29//!
30//! // A channel allows us to move the wrapped value between threads.
31//! let (sender, receiver) = channel();
32//!
33//! let t = thread::spawn(move || {
34//!     // This would panic (because of accessing in the wrong thread):
35//!     // let value = wrapped_value.get().unwrap();
36//!
37//!     // Move SendWrapper back to main thread, so it can be dropped from there.
38//!     // If you leave this out the thread will panic because of dropping from wrong thread.
39//!     sender.send(wrapped_value).unwrap();
40//! });
41//!
42//! let wrapped_value = receiver.recv().unwrap();
43//!
44//! // Now you can use the value again.
45//! let value = wrapped_value.get().unwrap();
46//!
47//! let mut wrapped_value = wrapped_value;
48//!
49//! // You can also get a mutable reference to the value.
50//! let value = wrapped_value.get_mut().unwrap();
51//! ```
52//!
53//! # Features
54//!
55//! This crate exposes several optional features:
56//!
57//! - `futures`: Enables [`Future`] and [`Stream`] implementations for
58//!   [`SendWrapper`].
59//! - `current_thread_id`: Uses the unstable [`std::thread::current_id`] API (on
60//!   nightly Rust) to track the originating thread more efficiently.
61//! - `nightly`: Enables nightly-only, experimental functionality used by this
62//!   crate (including support for `current_thread_id` as configured in
63//!   `Cargo.toml`).
64//!
65//! You can enable them in `Cargo.toml` like so:
66//!
67//! ```toml
68//! compio-send-wrapper = { version = "...", features = ["futures"] }
69//! # or, for example:
70//! # compio-send-wrapper = { version = "...", features = ["futures", "current_thread_id"] }
71//! ```
72//!
73//! # License
74//!
75//! `compio-send-wrapper` is distributed under the terms of both the MIT license
76//! and the Apache License (Version 2.0).
77//!
//! See LICENSE-APACHE.txt and LICENSE-MIT.txt for details.
79//!
80//! [Rust]: https://www.rust-lang.org
81//! [`Future`]: std::future::Future
82//! [`Stream`]: futures_core::Stream
83// To build docs locally use `RUSTDOCFLAGS="--cfg docsrs" cargo doc --open --all-features`
84#![cfg_attr(docsrs, feature(doc_cfg))]
85#![cfg_attr(
86    all(not(loom), feature = "current_thread_id"),
87    feature(current_thread_id)
88)]
89#![warn(missing_docs)]
90
91#[cfg(feature = "futures")]
92#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
93mod futures;
94
95use std::{
96    fmt,
97    mem::{self, ManuallyDrop},
98    pin::Pin,
99};
100
101cfg_if::cfg_if! {
102    if #[cfg(any(loom, not(feature = "current_thread_id")))] {
103        #[cfg(loom)]
104        use loom::{thread_local, cell::Cell, thread::{self, ThreadId}};
105        #[cfg(not(loom))]
106        use std::{thread_local, cell::Cell, thread::{self, ThreadId}};
107
108        thread_local! {
109            static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
110        }
111
112        /// Get the current [`ThreadId`].
113        pub(crate) fn current_id() -> ThreadId {
114            THREAD_ID.with(|id| id.get())
115        }
116    } else {
117        use std::thread::{self, current_id, ThreadId};
118    }
119}
120
/// A wrapper which allows you to move around non-[`Send`]-types between
/// threads, as long as you access the contained value only from within the
/// original thread and make sure that it is dropped from within the original
/// thread.
pub struct SendWrapper<T> {
    /// The wrapped value. Held in a [`ManuallyDrop`] so that the wrapper's own
    /// `Drop` impl decides whether `T`'s destructor is run.
    data: ManuallyDrop<T>,
    /// Id of the thread that created the wrapper; `valid` compares it against
    /// the id of the calling thread.
    thread_id: ThreadId,
}
129
130impl<T> SendWrapper<T> {
131    /// Create a `SendWrapper<T>` wrapper around a value of type `T`.
132    /// The wrapper takes ownership of the value.
133    #[inline]
134    pub fn new(data: T) -> SendWrapper<T> {
135        SendWrapper {
136            data: ManuallyDrop::new(data),
137            thread_id: current_id(),
138        }
139    }
140
141    /// Returns `true` if the value can be safely accessed from within the
142    /// current thread.
143    #[inline]
144    pub fn valid(&self) -> bool {
145        self.thread_id == current_id()
146    }
147
148    /// Takes the value out of the `SendWrapper<T>`.
149    ///
150    /// # Safety
151    ///
152    /// The caller should be in the same thread as the creator.
153    pub unsafe fn take_unchecked(self) -> T {
154        // Prevent drop() from being called, as it would drop `self.data` twice
155        let mut this = ManuallyDrop::new(self);
156
157        // Safety:
158        // - The caller of this unsafe function guarantees that it's valid to access `T`
159        //   from the current thread (the safe `take` method enforces this precondition
160        //   before calling `take_unchecked`).
161        // - We only move out from `self.data` here and in drop, so `self.data` is
162        //   present
163        unsafe { ManuallyDrop::take(&mut this.data) }
164    }
165
166    /// Takes the value out of the `SendWrapper<T>`.
167    ///
168    /// # Panics
169    ///
170    /// Panics if it is called from a different thread than the one the
171    /// `SendWrapper<T>` instance has been created with.
172    #[track_caller]
173    pub fn take(self) -> T {
174        if self.valid() {
175            // SAFETY: the same thread as the creator
176            unsafe { self.take_unchecked() }
177        } else {
178            invalid_deref()
179        }
180    }
181
182    /// Returns a reference to the contained value.
183    ///
184    /// # Safety
185    ///
186    /// The caller should be in the same thread as the creator.
187    #[inline]
188    pub unsafe fn get_unchecked(&self) -> &T {
189        &self.data
190    }
191
192    /// Returns a mutable reference to the contained value.
193    ///
194    /// # Safety
195    ///
196    /// The caller should be in the same thread as the creator.
197    #[inline]
198    pub unsafe fn get_unchecked_mut(&mut self) -> &mut T {
199        &mut self.data
200    }
201
202    /// Returns a pinned reference to the contained value.
203    ///
204    /// # Safety
205    ///
206    /// The caller should be in the same thread as the creator.
207    #[inline]
208    pub unsafe fn get_unchecked_pinned(self: Pin<&Self>) -> Pin<&T> {
209        // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
210        unsafe { self.map_unchecked(|s| &*s.data) }
211    }
212
213    /// Returns a pinned mutable reference to the contained value.
214    ///
215    /// # Safety
216    ///
217    /// The caller should be in the same thread as the creator.
218    #[inline]
219    pub unsafe fn get_unchecked_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
220        // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
221        unsafe { self.map_unchecked_mut(|s| &mut *s.data) }
222    }
223
224    /// Returns a reference to the contained value, if valid.
225    #[inline]
226    pub fn get(&self) -> Option<&T> {
227        if self.valid() { Some(&self.data) } else { None }
228    }
229
230    /// Returns a mutable reference to the contained value, if valid.
231    #[inline]
232    pub fn get_mut(&mut self) -> Option<&mut T> {
233        if self.valid() {
234            Some(&mut self.data)
235        } else {
236            None
237        }
238    }
239
240    /// Returns a pinned reference to the contained value, if valid.
241    #[inline]
242    pub fn get_pinned(self: Pin<&Self>) -> Option<Pin<&T>> {
243        if self.valid() {
244            // SAFETY: the same thread as the creator
245            Some(unsafe { self.get_unchecked_pinned() })
246        } else {
247            None
248        }
249    }
250
251    /// Returns a pinned mutable reference to the contained value, if valid.
252    #[inline]
253    pub fn get_pinned_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
254        if self.valid() {
255            // SAFETY: the same thread as the creator
256            Some(unsafe { self.get_unchecked_pinned_mut() })
257        } else {
258            None
259        }
260    }
261
262    /// Returns a tracker that can be used to check if the current thread is
263    /// the same as the creator thread.
264    #[inline]
265    pub fn tracker(&self) -> SendWrapper<()> {
266        SendWrapper {
267            data: ManuallyDrop::new(()),
268            thread_id: self.thread_id,
269        }
270    }
271}
272
// SAFETY: the wrapper itself may cross threads because every safe path to the
// inner `T` first checks `valid()` (the calling thread is the creator thread):
// the checked accessors return `None`, `take`/`clone` panic, and `Drop` only
// runs `T`'s destructor on the creator thread or when `T` needs no drop at
// all. The `*_unchecked` accessors are `unsafe` and push that obligation onto
// the caller.
unsafe impl<T> Send for SendWrapper<T> {}
// SAFETY: same reasoning as for `Send` — shared references handed to other
// threads cannot reach `T` through the safe API.
unsafe impl<T> Sync for SendWrapper<T> {}
275
276impl<T> Drop for SendWrapper<T> {
277    /// Drops the contained value.
278    ///
279    /// # Panics
280    ///
281    /// Dropping panics if it is done from a different thread than the one the
282    /// `SendWrapper<T>` instance has been created with.
283    ///
284    /// Exceptions:
285    /// - There is no extra panic if the thread is already panicking/unwinding.
286    ///   This is because otherwise there would be double panics (usually
287    ///   resulting in an abort) when dereferencing from a wrong thread.
288    /// - If `T` has a trivial drop ([`needs_drop::<T>()`] is false) then this
289    ///   method never panics.
290    ///
291    /// [`needs_drop::<T>()`]: std::mem::needs_drop
292    #[track_caller]
293    fn drop(&mut self) {
294        // If the drop is trivial (`needs_drop` = false), then dropping `T` can't access
295        // it and so it can be safely dropped on any thread.
296        if !mem::needs_drop::<T>() || self.valid() {
297            unsafe {
298                // Drop the inner value
299                //
300                // SAFETY:
301                // - We've just checked that it's valid to drop `T` on this thread
302                // - We only move out from `self.data` here and in drop, so `self.data` is
303                //   present
304                ManuallyDrop::drop(&mut self.data);
305            }
306        } else {
307            invalid_drop()
308        }
309    }
310}
311
312impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
313    /// Formats the value using the given formatter.
314    ///
315    /// If the `SendWrapper<T>` is formatted from a different thread than the
316    /// one it was created on, the `data` field is shown as `"<invalid>"`
317    /// instead of causing a panic.
318    #[track_caller]
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        let mut f = f.debug_struct("SendWrapper");
321        if let Some(data) = self.get() {
322            f.field("data", data);
323        } else {
324            f.field("data", &"<invalid>");
325        }
326        f.field("thread_id", &self.thread_id).finish()
327    }
328}
329
330impl<T: Clone> Clone for SendWrapper<T> {
331    /// Returns a copy of the value.
332    ///
333    /// # Panics
334    ///
335    /// Cloning panics if it is done from a different thread than the one
336    /// the `SendWrapper<T>` instance has been created with.
337    #[track_caller]
338    fn clone(&self) -> Self {
339        Self::new(self.get().unwrap_or_else(|| invalid_deref()).clone())
340    }
341}
342
/// Panics to report that a `SendWrapper<T>` was accessed from a thread other
/// than its creator.
///
/// Marked cold and never inlined to keep the panic machinery out of the
/// callers' fast paths; `#[track_caller]` makes the panic location point at
/// the caller.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_deref() -> ! {
    const MESSAGE: &str = "Accessed SendWrapper<T> variable from a thread \
                           different to the one it has been created with.";

    panic!("{}", MESSAGE)
}
352
/// Panics to report that a `SendWrapper<T>` was polled from a thread other
/// than its creator.
///
/// Only compiled with the `futures` feature. Marked cold and never inlined;
/// `#[track_caller]` makes the panic location point at the caller.
#[cold]
#[inline(never)]
#[track_caller]
#[cfg(feature = "futures")]
fn invalid_poll() -> ! {
    const MESSAGE: &str = "Polling SendWrapper<T> variable from a thread \
                           different to the one it has been created with.";

    panic!("{}", MESSAGE)
}
363
/// Reports that a `SendWrapper<T>` was dropped on a thread other than its
/// creator.
///
/// Panics unless the current thread is already panicking, in which case it
/// returns silently so that unwinding (possibly triggered by an access from
/// the wrong thread) does not turn into a double panic.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_drop() {
    const MESSAGE: &str = "Dropped SendWrapper<T> variable from a thread \
                           different to the one it has been created with.";

    if thread::panicking() {
        // Already unwinding: stay quiet rather than panic a second time.
        return;
    }

    // Not unwinding, so the wrong-thread drop itself is the error to report.
    panic!("{}", MESSAGE)
}
377
/// Unit tests covering the checked accessors, `take`, `Clone`, `Debug`, and
/// the drop behaviour on and off the creator thread.
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        rc::Rc,
        sync::{Arc, mpsc::channel},
        thread,
    };

    use super::SendWrapper;

    #[test]
    fn get_and_get_mut_on_creator_thread_and_pinned_variants() {
        let mut wrapper = SendWrapper::new(1_i32);

        // On the creator thread, the plain accessors should return Some.
        let r = wrapper.get();
        assert!(r.is_some());
        assert_eq!(*r.unwrap(), 1);

        let r_mut = wrapper.get_mut();
        assert!(r_mut.is_some());
        *r_mut.unwrap() = 2;

        // The change via get_mut should be visible via get as well.
        let r_after = wrapper.get();
        assert!(r_after.is_some());
        assert_eq!(*r_after.unwrap(), 2);

        // Pinned shared reference should also succeed on the creator thread.
        let pinned = Pin::new(&wrapper);
        let pinned_ref = pinned.get_pinned();
        assert!(pinned_ref.is_some());
        assert_eq!(*pinned_ref.unwrap(), 2);

        // Pinned mutable reference should succeed and allow mutation.
        let mut wrapper2 = SendWrapper::new(10_i32);
        let pinned_mut = Pin::new(&mut wrapper2);
        let pinned_mut_ref = pinned_mut.get_pinned_mut();
        assert!(pinned_mut_ref.is_some());
        *pinned_mut_ref.unwrap() = 11;

        let after_mut = wrapper2.get();
        assert!(after_mut.is_some());
        assert_eq!(*after_mut.unwrap(), 11);
    }

    #[test]
    fn accessors_return_none_on_non_creator_thread() {
        let mut wrapper = SendWrapper::new(123_i32);

        // Move the wrapper to another thread; that thread is not the creator.
        // `i32` has a trivial drop, so dropping it there does not panic.
        let handle = thread::spawn(move || {
            // Plain accessors should return None on non-creator thread.
            assert!(wrapper.get().is_none());
            assert!(wrapper.get_mut().is_none());

            // Pinned accessors should also return None on non-creator thread.
            let pinned = Pin::new(&wrapper);
            assert!(pinned.get_pinned().is_none());

            let mut wrapper = wrapper;
            let pinned_mut = Pin::new(&mut wrapper);
            assert!(pinned_mut.get_pinned_mut().is_none());
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_valid() {
        let (sender, receiver) = channel();
        let w = SendWrapper::new(Rc::new(42));
        assert!(w.valid());
        let t = thread::spawn(move || {
            // move SendWrapper back to main thread, so it can be dropped from there
            sender.send(w).unwrap();
        });
        let w2 = receiver.recv().unwrap();
        assert!(w2.valid());
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            assert!(!w.valid());
            // Hand the wrapper back so it is dropped on the creator thread.
            w
        });
        let join_result = t.join();
        assert!(join_result.is_ok());
    }

    #[test]
    fn test_drop_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            drop(w);
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    #[test]
    fn test_take() {
        let w = SendWrapper::new(Rc::new(42));
        let inner: Rc<usize> = w.take();
        assert_eq!(42, *inner);
    }

    #[test]
    fn test_take_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.take();
        });
        assert!(t.join().is_err());
    }

    #[test]
    fn test_sync() {
        // `Arc<T>` can only be sent to another thread if `T: Sync`, so moving
        // the `Arc` into the spawned thread proves `SendWrapper<T>: Sync`.
        let arc = Arc::new(SendWrapper::new(42));
        let t = thread::spawn(move || {
            // Bind to a named variable: a wildcard `let _ = arc;` does not
            // count as a use under edition 2021+ closure capture rules, so
            // the `Arc` would never actually be moved into this thread.
            let _arc = arc;
        });
        // `i32` has a trivial drop, so releasing the last `Arc` on the other
        // thread must not panic.
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_debug() {
        let w = SendWrapper::new(Rc::new(42));
        let info = format!("{:?}", w);
        assert!(info.contains("SendWrapper {"));
        assert!(info.contains("data: 42,"));
        assert!(info.contains("thread_id: ThreadId("));
    }

    #[test]
    fn test_debug_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let info = format!("{:?}", w);
            assert!(info.contains("SendWrapper {"));
            assert!(info.contains("data: \"<invalid>\","));
            assert!(info.contains("thread_id: ThreadId("));
            // Hand the wrapper back so it is dropped on the creator thread.
            w
        });
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_clone() {
        let w1 = SendWrapper::new(Rc::new(42));
        let w2 = w1.clone();
        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
    }

    #[test]
    fn test_clone_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.clone();
        });
        assert!(t.join().is_err());
    }
}