Skip to main content

compio_send_wrapper/
lib.rs

1// Copyright 2017 Thomas Keh.
2// Copyright 2024 compio-rs
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! This [Rust] library implements a wrapper type called [`SendWrapper`] which
11//! allows you to move around non-[`Send`] types between threads, as long as you
12//! access the contained value only from within the original thread. You also
13//! have to make sure that the wrapper is dropped from within the original
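//! thread. The checked accessors such as [`SendWrapper::get`] and
//! [`SendWrapper::get_mut`] return `None` when called from another thread,
//! while [`SendWrapper::take`], cloning, and dropping a value that needs
//! dropping from another thread panic.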
15//! [`SendWrapper<T>`] implements [`Send`] and [`Sync`] for any type `T`.
16//!
17//! # Examples
18//!
19//! ```rust
20//! use std::{rc::Rc, sync::mpsc::channel, thread};
21//!
22//! use compio_send_wrapper::SendWrapper;
23//!
24//! // Rc is a non-Send type.
25//! let value = Rc::new(42);
26//!
27//! // We now wrap the value with `SendWrapper` (value is moved inside).
28//! let wrapped_value = SendWrapper::new(value);
29//!
30//! // A channel allows us to move the wrapped value between threads.
31//! let (sender, receiver) = channel();
32//!
33//! let t = thread::spawn(move || {
//!     // This would panic: `get` returns `None` on any thread other than the creator, so `unwrap` fails:
35//!     // let value = wrapped_value.get().unwrap();
36//!
37//!     // Move SendWrapper back to main thread, so it can be dropped from there.
38//!     // If you leave this out the thread will panic because of dropping from wrong thread.
39//!     sender.send(wrapped_value).unwrap();
40//! });
41//!
42//! let wrapped_value = receiver.recv().unwrap();
43//!
44//! // Now you can use the value again.
45//! let value = wrapped_value.get().unwrap();
46//!
47//! let mut wrapped_value = wrapped_value;
48//!
49//! // You can also get a mutable reference to the value.
50//! let value = wrapped_value.get_mut().unwrap();
51//! ```
52//!
53//! # Features
54//!
55//! This crate exposes several optional features:
56//!
57//! - `futures`: Enables [`Future`] and [`Stream`] implementations for
58//!   [`SendWrapper`].
59//! - `current_thread_id`: Uses the unstable [`std::thread::current_id`] API (on
60//!   nightly Rust) to track the originating thread more efficiently.
61//! - `nightly`: Enables nightly-only, experimental functionality used by this
62//!   crate (including support for `current_thread_id` as configured in
63//!   `Cargo.toml`).
64//!
65//! You can enable them in `Cargo.toml` like so:
66//!
67//! ```toml
68//! compio-send-wrapper = { version = "...", features = ["futures"] }
69//! # or, for example:
70//! # compio-send-wrapper = { version = "...", features = ["futures", "current_thread_id"] }
71//! ```
72//!
73//! # License
74//!
75//! `compio-send-wrapper` is distributed under the terms of both the MIT license
76//! and the Apache License (Version 2.0).
77//!
//! See LICENSE-APACHE.txt and LICENSE-MIT.txt for details.
79//!
80//! [Rust]: https://www.rust-lang.org
81//! [`Future`]: std::future::Future
82//! [`Stream`]: futures_core::Stream
83// To build docs locally use `RUSTDOCFLAGS="--cfg docsrs" cargo doc --open --all-features`
84#![cfg_attr(docsrs, feature(doc_cfg))]
85#![cfg_attr(feature = "current_thread_id", feature(current_thread_id))]
86#![warn(missing_docs)]
87
88#[cfg(feature = "futures")]
89#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
90mod futures;
91
92#[cfg(feature = "current_thread_id")]
93use std::thread::current_id;
94use std::{
95    fmt,
96    mem::{self, ManuallyDrop},
97    pin::Pin,
98    thread::{self, ThreadId},
99};
100
101#[cfg(not(feature = "current_thread_id"))]
102mod imp {
103    use std::{
104        cell::Cell,
105        thread::{self, ThreadId},
106    };
107    thread_local! {
108        static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
109    }
110
111    pub fn current_id() -> ThreadId {
112        THREAD_ID.get()
113    }
114}
115
116#[cfg(not(feature = "current_thread_id"))]
117use imp::current_id;
118
119/// A wrapper which allows you to move around non-[`Send`]-types between
120/// threads, as long as you access the contained value only from within the
121/// original thread and make sure that it is dropped from within the original
122/// thread.
123pub struct SendWrapper<T> {
124    data: ManuallyDrop<T>,
125    thread_id: ThreadId,
126}
127
128impl<T> SendWrapper<T> {
129    /// Create a `SendWrapper<T>` wrapper around a value of type `T`.
130    /// The wrapper takes ownership of the value.
131    #[inline]
132    pub fn new(data: T) -> SendWrapper<T> {
133        SendWrapper {
134            data: ManuallyDrop::new(data),
135            thread_id: current_id(),
136        }
137    }
138
139    /// Returns `true` if the value can be safely accessed from within the
140    /// current thread.
141    #[inline]
142    pub fn valid(&self) -> bool {
143        self.thread_id == current_id()
144    }
145
146    /// Takes the value out of the `SendWrapper<T>`.
147    ///
148    /// # Safety
149    ///
150    /// The caller should be in the same thread as the creator.
151    pub unsafe fn take_unchecked(self) -> T {
152        // Prevent drop() from being called, as it would drop `self.data` twice
153        let mut this = ManuallyDrop::new(self);
154
155        // Safety:
156        // - The caller of this unsafe function guarantees that it's valid to access `T`
157        //   from the current thread (the safe `take` method enforces this precondition
158        //   before calling `take_unchecked`).
159        // - We only move out from `self.data` here and in drop, so `self.data` is
160        //   present
161        unsafe { ManuallyDrop::take(&mut this.data) }
162    }
163
164    /// Takes the value out of the `SendWrapper<T>`.
165    ///
166    /// # Panics
167    ///
168    /// Panics if it is called from a different thread than the one the
169    /// `SendWrapper<T>` instance has been created with.
170    #[track_caller]
171    pub fn take(self) -> T {
172        if self.valid() {
173            // SAFETY: the same thread as the creator
174            unsafe { self.take_unchecked() }
175        } else {
176            invalid_deref()
177        }
178    }
179
180    /// Returns a reference to the contained value.
181    ///
182    /// # Safety
183    ///
184    /// The caller should be in the same thread as the creator.
185    #[inline]
186    pub unsafe fn get_unchecked(&self) -> &T {
187        &self.data
188    }
189
190    /// Returns a mutable reference to the contained value.
191    ///
192    /// # Safety
193    ///
194    /// The caller should be in the same thread as the creator.
195    #[inline]
196    pub unsafe fn get_unchecked_mut(&mut self) -> &mut T {
197        &mut self.data
198    }
199
200    /// Returns a pinned reference to the contained value.
201    ///
202    /// # Safety
203    ///
204    /// The caller should be in the same thread as the creator.
205    #[inline]
206    pub unsafe fn get_unchecked_pinned(self: Pin<&Self>) -> Pin<&T> {
207        // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
208        unsafe { self.map_unchecked(|s| &*s.data) }
209    }
210
211    /// Returns a pinned mutable reference to the contained value.
212    ///
213    /// # Safety
214    ///
215    /// The caller should be in the same thread as the creator.
216    #[inline]
217    pub unsafe fn get_unchecked_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
218        // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
219        unsafe { self.map_unchecked_mut(|s| &mut *s.data) }
220    }
221
222    /// Returns a reference to the contained value, if valid.
223    #[inline]
224    pub fn get(&self) -> Option<&T> {
225        if self.valid() { Some(&self.data) } else { None }
226    }
227
228    /// Returns a mutable reference to the contained value, if valid.
229    #[inline]
230    pub fn get_mut(&mut self) -> Option<&mut T> {
231        if self.valid() {
232            Some(&mut self.data)
233        } else {
234            None
235        }
236    }
237
238    /// Returns a pinned reference to the contained value, if valid.
239    #[inline]
240    pub fn get_pinned(self: Pin<&Self>) -> Option<Pin<&T>> {
241        if self.valid() {
242            // SAFETY: the same thread as the creator
243            Some(unsafe { self.get_unchecked_pinned() })
244        } else {
245            None
246        }
247    }
248
249    /// Returns a pinned mutable reference to the contained value, if valid.
250    #[inline]
251    pub fn get_pinned_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
252        if self.valid() {
253            // SAFETY: the same thread as the creator
254            Some(unsafe { self.get_unchecked_pinned_mut() })
255        } else {
256            None
257        }
258    }
259
260    /// Returns a tracker that can be used to check if the current thread is
261    /// the same as the creator thread.
262    #[inline]
263    pub fn tracker(&self) -> SendWrapper<()> {
264        SendWrapper {
265            data: ManuallyDrop::new(()),
266            thread_id: self.thread_id,
267        }
268    }
269}
270
271unsafe impl<T> Send for SendWrapper<T> {}
272unsafe impl<T> Sync for SendWrapper<T> {}
273
274impl<T> Drop for SendWrapper<T> {
275    /// Drops the contained value.
276    ///
277    /// # Panics
278    ///
279    /// Dropping panics if it is done from a different thread than the one the
280    /// `SendWrapper<T>` instance has been created with.
281    ///
282    /// Exceptions:
283    /// - There is no extra panic if the thread is already panicking/unwinding.
284    ///   This is because otherwise there would be double panics (usually
285    ///   resulting in an abort) when dereferencing from a wrong thread.
286    /// - If `T` has a trivial drop ([`needs_drop::<T>()`] is false) then this
287    ///   method never panics.
288    ///
289    /// [`needs_drop::<T>()`]: std::mem::needs_drop
290    #[track_caller]
291    fn drop(&mut self) {
292        // If the drop is trivial (`needs_drop` = false), then dropping `T` can't access
293        // it and so it can be safely dropped on any thread.
294        if !mem::needs_drop::<T>() || self.valid() {
295            unsafe {
296                // Drop the inner value
297                //
298                // SAFETY:
299                // - We've just checked that it's valid to drop `T` on this thread
300                // - We only move out from `self.data` here and in drop, so `self.data` is
301                //   present
302                ManuallyDrop::drop(&mut self.data);
303            }
304        } else {
305            invalid_drop()
306        }
307    }
308}
309
310impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
311    /// Formats the value using the given formatter.
312    ///
313    /// If the `SendWrapper<T>` is formatted from a different thread than the
314    /// one it was created on, the `data` field is shown as `"<invalid>"`
315    /// instead of causing a panic.
316    #[track_caller]
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        let mut f = f.debug_struct("SendWrapper");
319        if let Some(data) = self.get() {
320            f.field("data", data);
321        } else {
322            f.field("data", &"<invalid>");
323        }
324        f.field("thread_id", &self.thread_id).finish()
325    }
326}
327
328impl<T: Clone> Clone for SendWrapper<T> {
329    /// Returns a copy of the value.
330    ///
331    /// # Panics
332    ///
333    /// Cloning panics if it is done from a different thread than the one
334    /// the `SendWrapper<T>` instance has been created with.
335    #[track_caller]
336    fn clone(&self) -> Self {
337        Self::new(self.get().unwrap_or_else(|| invalid_deref()).clone())
338    }
339}
340
341#[cold]
342#[inline(never)]
343#[track_caller]
344fn invalid_deref() -> ! {
345    const DEREF_ERROR: &str = "Accessed SendWrapper<T> variable from a thread different to the \
346                               one it has been created with.";
347
348    panic!("{}", DEREF_ERROR)
349}
350
351#[cold]
352#[inline(never)]
353#[track_caller]
354#[cfg(feature = "futures")]
355fn invalid_poll() -> ! {
356    const POLL_ERROR: &str = "Polling SendWrapper<T> variable from a thread different to the one \
357                              it has been created with.";
358
359    panic!("{}", POLL_ERROR)
360}
361
362#[cold]
363#[inline(never)]
364#[track_caller]
365fn invalid_drop() {
366    const DROP_ERROR: &str = "Dropped SendWrapper<T> variable from a thread different to the one \
367                              it has been created with.";
368
369    if !thread::panicking() {
370        // panic because of dropping from wrong thread
371        // only do this while not unwinding (could be caused by deref from wrong thread)
372        panic!("{}", DROP_ERROR)
373    }
374}
375
376#[cfg(test)]
377mod tests {
378    use std::{
379        pin::Pin,
380        rc::Rc,
381        sync::{Arc, mpsc::channel},
382        thread,
383    };
384
385    use super::SendWrapper;
386
387    #[test]
388    fn get_and_get_mut_on_creator_thread_and_pinned_variants() {
389        let mut wrapper = SendWrapper::new(1_i32);
390
391        // On the creator thread, the plain accessors should return Some.
392        let r = wrapper.get();
393        assert!(r.is_some());
394        assert_eq!(*r.unwrap(), 1);
395
396        let r_mut = wrapper.get_mut();
397        assert!(r_mut.is_some());
398        *r_mut.unwrap() = 2;
399
400        // The change via get_mut should be visible via get as well.
401        let r_after = wrapper.get();
402        assert!(r_after.is_some());
403        assert_eq!(*r_after.unwrap(), 2);
404
405        // Pinned shared reference should also succeed on the creator thread.
406        let pinned = Pin::new(&wrapper);
407        let pinned_ref = pinned.get_pinned();
408        assert!(pinned_ref.is_some());
409        assert_eq!(*pinned_ref.unwrap(), 2);
410
411        // Pinned mutable reference should succeed and allow mutation.
412        let mut wrapper2 = SendWrapper::new(10_i32);
413        let pinned_mut = Pin::new(&mut wrapper2);
414        let pinned_mut_ref = pinned_mut.get_pinned_mut();
415        assert!(pinned_mut_ref.is_some());
416        *pinned_mut_ref.unwrap() = 11;
417
418        let after_mut = wrapper2.get();
419        assert!(after_mut.is_some());
420        assert_eq!(*after_mut.unwrap(), 11);
421    }
422
423    #[test]
424    fn accessors_return_none_on_non_creator_thread() {
425        let mut wrapper = SendWrapper::new(123_i32);
426
427        // Move the wrapper to another thread; that thread is not the creator.
428        let handle = thread::spawn(move || {
429            // Plain accessors should return None on non-creator thread.
430            assert!(wrapper.get().is_none());
431            assert!(wrapper.get_mut().is_none());
432
433            // Pinned accessors should also return None on non-creator thread.
434            let pinned = Pin::new(&wrapper);
435            assert!(pinned.get_pinned().is_none());
436
437            let mut wrapper = wrapper;
438            let pinned_mut = Pin::new(&mut wrapper);
439            assert!(pinned_mut.get_pinned_mut().is_none());
440        });
441
442        handle.join().unwrap();
443    }
444
445    #[test]
446    fn test_valid() {
447        let (sender, receiver) = channel();
448        let w = SendWrapper::new(Rc::new(42));
449        assert!(w.valid());
450        let t = thread::spawn(move || {
451            // move SendWrapper back to main thread, so it can be dropped from there
452            sender.send(w).unwrap();
453        });
454        let w2 = receiver.recv().unwrap();
455        assert!(w2.valid());
456        assert!(t.join().is_ok());
457    }
458
459    #[test]
460    fn test_invalid() {
461        let w = SendWrapper::new(Rc::new(42));
462        let t = thread::spawn(move || {
463            assert!(!w.valid());
464            w
465        });
466        let join_result = t.join();
467        assert!(join_result.is_ok());
468    }
469
470    #[test]
471    fn test_drop_panic() {
472        let w = SendWrapper::new(Rc::new(42));
473        let t = thread::spawn(move || {
474            drop(w);
475        });
476        let join_result = t.join();
477        assert!(join_result.is_err());
478    }
479
480    #[test]
481    fn test_take() {
482        let w = SendWrapper::new(Rc::new(42));
483        let inner: Rc<usize> = w.take();
484        assert_eq!(42, *inner);
485    }
486
487    #[test]
488    fn test_take_panic() {
489        let w = SendWrapper::new(Rc::new(42));
490        let t = thread::spawn(move || {
491            let _ = w.take();
492        });
493        assert!(t.join().is_err());
494    }
495    #[test]
496    fn test_sync() {
497        // Arc<T> can only be sent to another thread if T Sync
498        let arc = Arc::new(SendWrapper::new(42));
499        thread::spawn(move || {
500            let _ = arc;
501        });
502    }
503
504    #[test]
505    fn test_debug() {
506        let w = SendWrapper::new(Rc::new(42));
507        let info = format!("{:?}", w);
508        assert!(info.contains("SendWrapper {"));
509        assert!(info.contains("data: 42,"));
510        assert!(info.contains("thread_id: ThreadId("));
511    }
512
513    #[test]
514    fn test_debug_invalid() {
515        let w = SendWrapper::new(Rc::new(42));
516        let t = thread::spawn(move || {
517            let info = format!("{:?}", w);
518            assert!(info.contains("SendWrapper {"));
519            assert!(info.contains("data: \"<invalid>\","));
520            assert!(info.contains("thread_id: ThreadId("));
521            w
522        });
523        assert!(t.join().is_ok());
524    }
525
526    #[test]
527    fn test_clone() {
528        let w1 = SendWrapper::new(Rc::new(42));
529        let w2 = w1.clone();
530        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
531    }
532
533    #[test]
534    fn test_clone_panic() {
535        let w = SendWrapper::new(Rc::new(42));
536        let t = thread::spawn(move || {
537            let _ = w.clone();
538        });
539        assert!(t.join().is_err());
540    }
541}