compio_send_wrapper/lib.rs
1// Copyright 2017 Thomas Keh.
2// Copyright 2024 compio-rs
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! This [Rust] library implements a wrapper type called [`SendWrapper`] which
11//! allows you to move around non-[`Send`] types between threads, as long as you
12//! access the contained value only from within the original thread. You also
13//! have to make sure that the wrapper is dropped from within the original
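//! thread. Violations are detected at run time: the checked accessors (such as
//! [`SendWrapper::get`]) return `None`, while [`SendWrapper::take`], `clone`
//! and drop (for types that need dropping) panic.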
15//! [`SendWrapper<T>`] implements [`Send`] and [`Sync`] for any type `T`.
16//!
17//! # Examples
18//!
19//! ```rust
20//! use std::{rc::Rc, sync::mpsc::channel, thread};
21//!
22//! use compio_send_wrapper::SendWrapper;
23//!
24//! // Rc is a non-Send type.
25//! let value = Rc::new(42);
26//!
27//! // We now wrap the value with `SendWrapper` (value is moved inside).
28//! let wrapped_value = SendWrapper::new(value);
29//!
30//! // A channel allows us to move the wrapped value between threads.
31//! let (sender, receiver) = channel();
32//!
33//! let t = thread::spawn(move || {
34//! // This would panic (because of accessing in the wrong thread):
35//! // let value = wrapped_value.get().unwrap();
36//!
37//! // Move SendWrapper back to main thread, so it can be dropped from there.
38//! // If you leave this out the thread will panic because of dropping from wrong thread.
39//! sender.send(wrapped_value).unwrap();
40//! });
41//!
42//! let wrapped_value = receiver.recv().unwrap();
43//!
44//! // Now you can use the value again.
45//! let value = wrapped_value.get().unwrap();
46//!
47//! let mut wrapped_value = wrapped_value;
48//!
49//! // You can also get a mutable reference to the value.
50//! let value = wrapped_value.get_mut().unwrap();
51//! ```
52//!
53//! # Features
54//!
55//! This crate exposes several optional features:
56//!
57//! - `futures`: Enables [`Future`] and [`Stream`] implementations for
58//! [`SendWrapper`].
59//! - `current_thread_id`: Uses the unstable [`std::thread::current_id`] API (on
60//! nightly Rust) to track the originating thread more efficiently.
61//! - `nightly`: Enables nightly-only, experimental functionality used by this
62//! crate (including support for `current_thread_id` as configured in
63//! `Cargo.toml`).
64//!
65//! You can enable them in `Cargo.toml` like so:
66//!
67//! ```toml
68//! compio-send-wrapper = { version = "...", features = ["futures"] }
69//! # or, for example:
70//! # compio-send-wrapper = { version = "...", features = ["futures", "current_thread_id"] }
71//! ```
72//!
73//! # License
74//!
75//! `compio-send-wrapper` is distributed under the terms of both the MIT license
76//! and the Apache License (Version 2.0).
77//!
//! See LICENSE-APACHE.txt and LICENSE-MIT.txt for details.
79//!
80//! [Rust]: https://www.rust-lang.org
81//! [`Future`]: std::future::Future
82//! [`Stream`]: futures_core::Stream
83// To build docs locally use `RUSTDOCFLAGS="--cfg docsrs" cargo doc --open --all-features`
84#![cfg_attr(docsrs, feature(doc_cfg))]
85#![cfg_attr(feature = "current_thread_id", feature(current_thread_id))]
86#![warn(missing_docs)]
87
88#[cfg(feature = "futures")]
89#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
90mod futures;
91
92#[cfg(feature = "current_thread_id")]
93use std::thread::current_id;
94use std::{
95 fmt,
96 mem::{self, ManuallyDrop},
97 pin::Pin,
98 thread::{self, ThreadId},
99};
100
#[cfg(not(feature = "current_thread_id"))]
mod imp {
    //! Fallback used when the `current_thread_id` feature is disabled: the
    //! current thread's id is kept in a thread-local slot.

    use std::cell::Cell;
    use std::thread::{self, ThreadId};

    thread_local! {
        /// Per-thread slot initialised from `thread::current().id()`.
        static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
    }

    /// Returns the [`ThreadId`] stored in the calling thread's slot.
    pub fn current_id() -> ThreadId {
        THREAD_ID.with(Cell::get)
    }
}
115
116#[cfg(not(feature = "current_thread_id"))]
117use imp::current_id;
118
/// A wrapper which allows you to move around non-[`Send`]-types between
/// threads, as long as you access the contained value only from within the
/// original thread and make sure that it is dropped from within the original
/// thread.
pub struct SendWrapper<T> {
    /// The wrapped value. Kept in `ManuallyDrop` so that the `Drop` impl and
    /// `take_unchecked` decide when it is dropped or moved out.
    data: ManuallyDrop<T>,
    /// Id of the creator thread: recorded by `new` (and copied by `tracker`),
    /// then compared against the current thread's id in `valid`.
    thread_id: ThreadId,
}
127
128impl<T> SendWrapper<T> {
129 /// Create a `SendWrapper<T>` wrapper around a value of type `T`.
130 /// The wrapper takes ownership of the value.
131 #[inline]
132 pub fn new(data: T) -> SendWrapper<T> {
133 SendWrapper {
134 data: ManuallyDrop::new(data),
135 thread_id: current_id(),
136 }
137 }
138
139 /// Returns `true` if the value can be safely accessed from within the
140 /// current thread.
141 #[inline]
142 pub fn valid(&self) -> bool {
143 self.thread_id == current_id()
144 }
145
146 /// Takes the value out of the `SendWrapper<T>`.
147 ///
148 /// # Safety
149 ///
150 /// The caller should be in the same thread as the creator.
151 pub unsafe fn take_unchecked(self) -> T {
152 // Prevent drop() from being called, as it would drop `self.data` twice
153 let mut this = ManuallyDrop::new(self);
154
155 // Safety:
156 // - The caller of this unsafe function guarantees that it's valid to access `T`
157 // from the current thread (the safe `take` method enforces this precondition
158 // before calling `take_unchecked`).
159 // - We only move out from `self.data` here and in drop, so `self.data` is
160 // present
161 unsafe { ManuallyDrop::take(&mut this.data) }
162 }
163
164 /// Takes the value out of the `SendWrapper<T>`.
165 ///
166 /// # Panics
167 ///
168 /// Panics if it is called from a different thread than the one the
169 /// `SendWrapper<T>` instance has been created with.
170 #[track_caller]
171 pub fn take(self) -> T {
172 if self.valid() {
173 // SAFETY: the same thread as the creator
174 unsafe { self.take_unchecked() }
175 } else {
176 invalid_deref()
177 }
178 }
179
180 /// Returns a reference to the contained value.
181 ///
182 /// # Safety
183 ///
184 /// The caller should be in the same thread as the creator.
185 #[inline]
186 pub unsafe fn get_unchecked(&self) -> &T {
187 &self.data
188 }
189
190 /// Returns a mutable reference to the contained value.
191 ///
192 /// # Safety
193 ///
194 /// The caller should be in the same thread as the creator.
195 #[inline]
196 pub unsafe fn get_unchecked_mut(&mut self) -> &mut T {
197 &mut self.data
198 }
199
200 /// Returns a pinned reference to the contained value.
201 ///
202 /// # Safety
203 ///
204 /// The caller should be in the same thread as the creator.
205 #[inline]
206 pub unsafe fn get_unchecked_pinned(self: Pin<&Self>) -> Pin<&T> {
207 // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
208 unsafe { self.map_unchecked(|s| &*s.data) }
209 }
210
211 /// Returns a pinned mutable reference to the contained value.
212 ///
213 /// # Safety
214 ///
215 /// The caller should be in the same thread as the creator.
216 #[inline]
217 pub unsafe fn get_unchecked_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
218 // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
219 unsafe { self.map_unchecked_mut(|s| &mut *s.data) }
220 }
221
222 /// Returns a reference to the contained value, if valid.
223 #[inline]
224 pub fn get(&self) -> Option<&T> {
225 if self.valid() { Some(&self.data) } else { None }
226 }
227
228 /// Returns a mutable reference to the contained value, if valid.
229 #[inline]
230 pub fn get_mut(&mut self) -> Option<&mut T> {
231 if self.valid() {
232 Some(&mut self.data)
233 } else {
234 None
235 }
236 }
237
238 /// Returns a pinned reference to the contained value, if valid.
239 #[inline]
240 pub fn get_pinned(self: Pin<&Self>) -> Option<Pin<&T>> {
241 if self.valid() {
242 // SAFETY: the same thread as the creator
243 Some(unsafe { self.get_unchecked_pinned() })
244 } else {
245 None
246 }
247 }
248
249 /// Returns a pinned mutable reference to the contained value, if valid.
250 #[inline]
251 pub fn get_pinned_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
252 if self.valid() {
253 // SAFETY: the same thread as the creator
254 Some(unsafe { self.get_unchecked_pinned_mut() })
255 } else {
256 None
257 }
258 }
259
260 /// Returns a tracker that can be used to check if the current thread is
261 /// the same as the creator thread.
262 #[inline]
263 pub fn tracker(&self) -> SendWrapper<()> {
264 SendWrapper {
265 data: ManuallyDrop::new(()),
266 thread_id: self.thread_id,
267 }
268 }
269}
270
// SAFETY (as far as this file shows): the safe accessors (`get`, `get_mut`,
// `get_pinned`, `get_pinned_mut`, `take`) and the `Clone`/`Debug` impls check
// `valid()` before touching `T`, and `Drop` only drops `T` on the creator
// thread or when `T` needs no drop. The `*_unchecked` accessors are `unsafe`
// and leave that check to the caller.
// NOTE(review): the `futures` module is not visible here — confirm it upholds
// the same thread check.
unsafe impl<T> Send for SendWrapper<T> {}
unsafe impl<T> Sync for SendWrapper<T> {}
273
274impl<T> Drop for SendWrapper<T> {
275 /// Drops the contained value.
276 ///
277 /// # Panics
278 ///
279 /// Dropping panics if it is done from a different thread than the one the
280 /// `SendWrapper<T>` instance has been created with.
281 ///
282 /// Exceptions:
283 /// - There is no extra panic if the thread is already panicking/unwinding.
284 /// This is because otherwise there would be double panics (usually
285 /// resulting in an abort) when dereferencing from a wrong thread.
286 /// - If `T` has a trivial drop ([`needs_drop::<T>()`] is false) then this
287 /// method never panics.
288 ///
289 /// [`needs_drop::<T>()`]: std::mem::needs_drop
290 #[track_caller]
291 fn drop(&mut self) {
292 // If the drop is trivial (`needs_drop` = false), then dropping `T` can't access
293 // it and so it can be safely dropped on any thread.
294 if !mem::needs_drop::<T>() || self.valid() {
295 unsafe {
296 // Drop the inner value
297 //
298 // SAFETY:
299 // - We've just checked that it's valid to drop `T` on this thread
300 // - We only move out from `self.data` here and in drop, so `self.data` is
301 // present
302 ManuallyDrop::drop(&mut self.data);
303 }
304 } else {
305 invalid_drop()
306 }
307 }
308}
309
310impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
311 /// Formats the value using the given formatter.
312 ///
313 /// If the `SendWrapper<T>` is formatted from a different thread than the
314 /// one it was created on, the `data` field is shown as `"<invalid>"`
315 /// instead of causing a panic.
316 #[track_caller]
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 let mut f = f.debug_struct("SendWrapper");
319 if let Some(data) = self.get() {
320 f.field("data", data);
321 } else {
322 f.field("data", &"<invalid>");
323 }
324 f.field("thread_id", &self.thread_id).finish()
325 }
326}
327
328impl<T: Clone> Clone for SendWrapper<T> {
329 /// Returns a copy of the value.
330 ///
331 /// # Panics
332 ///
333 /// Cloning panics if it is done from a different thread than the one
334 /// the `SendWrapper<T>` instance has been created with.
335 #[track_caller]
336 fn clone(&self) -> Self {
337 Self::new(self.get().unwrap_or_else(|| invalid_deref()).clone())
338 }
339}
340
/// Panics with the message for accessing a wrapped value from a thread other
/// than its creator. `#[track_caller]` lets the panic location point at the
/// caller when that caller is `#[track_caller]` as well.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_deref() -> ! {
    const MESSAGE: &str = "Accessed SendWrapper<T> variable from a thread \
                           different to the one it has been created with.";

    panic!("{}", MESSAGE)
}
350
/// Panics with the message for polling a wrapped value from a thread other
/// than its creator. Only compiled with the `futures` feature; presumably
/// called from the `futures` module (not visible here).
#[cold]
#[inline(never)]
#[track_caller]
#[cfg(feature = "futures")]
fn invalid_poll() -> ! {
    const MESSAGE: &str = "Polling SendWrapper<T> variable from a thread \
                           different to the one it has been created with.";

    panic!("{}", MESSAGE)
}
361
362#[cold]
363#[inline(never)]
364#[track_caller]
365fn invalid_drop() {
366 const DROP_ERROR: &str = "Dropped SendWrapper<T> variable from a thread different to the one \
367 it has been created with.";
368
369 if !thread::panicking() {
370 // panic because of dropping from wrong thread
371 // only do this while not unwinding (could be caused by deref from wrong thread)
372 panic!("{}", DROP_ERROR)
373 }
374}
375
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        rc::Rc,
        sync::{Arc, mpsc::channel},
        thread,
    };

    use super::SendWrapper;

    /// On the creator thread every checked accessor yields the value, and
    /// writes through the mutable accessors are visible afterwards.
    #[test]
    fn get_and_get_mut_on_creator_thread_and_pinned_variants() {
        let mut wrapper = SendWrapper::new(1_i32);

        // On the creator thread, the plain accessors should return Some.
        assert_eq!(wrapper.get(), Some(&1));

        // The change via get_mut should be visible via get as well.
        *wrapper.get_mut().expect("creator thread may mutate") = 2;
        assert_eq!(wrapper.get(), Some(&2));

        // Pinned shared reference should also succeed on the creator thread.
        let pinned_ref = Pin::new(&wrapper).get_pinned();
        assert!(pinned_ref.is_some());
        assert_eq!(*pinned_ref.unwrap(), 2);

        // Pinned mutable reference should succeed and allow mutation.
        let mut wrapper2 = SendWrapper::new(10_i32);
        let pinned_mut_ref = Pin::new(&mut wrapper2).get_pinned_mut();
        assert!(pinned_mut_ref.is_some());
        *pinned_mut_ref.unwrap() = 11;

        assert_eq!(wrapper2.get(), Some(&11));
    }

    /// Off the creator thread every checked accessor yields `None`.
    #[test]
    fn accessors_return_none_on_non_creator_thread() {
        let mut wrapper = SendWrapper::new(123_i32);

        // Move the wrapper to another thread; that thread is not the creator.
        // (`i32` has no drop glue, so dropping it there does not panic.)
        let handle = thread::spawn(move || {
            // Plain accessors should return None on non-creator thread.
            assert!(wrapper.get().is_none());
            assert!(wrapper.get_mut().is_none());

            // Pinned accessors should also return None on non-creator thread.
            assert!(Pin::new(&wrapper).get_pinned().is_none());
            assert!(Pin::new(&mut wrapper).get_pinned_mut().is_none());
        });

        handle.join().unwrap();
    }

    /// A wrapper that travelled through another thread is valid again once it
    /// is back on the creator thread.
    #[test]
    fn test_valid() {
        let (sender, receiver) = channel();
        let w = SendWrapper::new(Rc::new(42));
        assert!(w.valid());
        let t = thread::spawn(move || {
            // move SendWrapper back to main thread, so it can be dropped from there
            sender.send(w).unwrap();
        });
        let w2 = receiver.recv().unwrap();
        assert!(w2.valid());
        assert!(t.join().is_ok());
    }

    /// `valid` is false on a thread other than the creator.
    #[test]
    fn test_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            assert!(!w.valid());
            // Hand the wrapper back so that it is dropped on the creator thread.
            w
        });
        let join_result = t.join();
        assert!(join_result.is_ok());
    }

    /// Dropping a value with drop glue on the wrong thread panics.
    #[test]
    fn test_drop_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            drop(w);
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    /// `take` returns the inner value on the creator thread.
    #[test]
    fn test_take() {
        let w = SendWrapper::new(Rc::new(42));
        let inner: Rc<usize> = w.take();
        assert_eq!(42, *inner);
    }

    /// `take` panics on a thread other than the creator.
    #[test]
    fn test_take_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.take();
        });
        assert!(t.join().is_err());
    }

    /// `SendWrapper<T>` must be `Sync` for `Arc<SendWrapper<T>>` to be `Send`.
    #[test]
    fn test_sync() {
        // Arc<T> can only be sent to another thread if T Sync
        let arc = Arc::new(SendWrapper::new(42));
        let t = thread::spawn(move || {
            // `drop(arc)` (unlike `let _ = arc;`) uses the variable, so the
            // closure really captures the `Arc` and it is sent to this thread.
            drop(arc);
        });
        // Join so that a panic on the spawned thread fails the test.
        assert!(t.join().is_ok());
    }

    /// The tracker reports validity for the creator thread of its source
    /// wrapper and may be dropped on any thread.
    #[test]
    fn test_tracker() {
        let w = SendWrapper::new(Rc::new(42));
        let tracker = w.tracker();
        assert!(tracker.valid());
        let t = thread::spawn(move || {
            assert!(!tracker.valid());
        });
        assert!(t.join().is_ok());
    }

    /// On the creator thread `Debug` shows the wrapped value.
    #[test]
    fn test_debug() {
        let w = SendWrapper::new(Rc::new(42));
        let info = format!("{:?}", w);
        assert!(info.contains("SendWrapper {"));
        assert!(info.contains("data: 42,"));
        assert!(info.contains("thread_id: ThreadId("));
    }

    /// Off the creator thread `Debug` shows a placeholder instead of panicking.
    #[test]
    fn test_debug_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let info = format!("{:?}", w);
            assert!(info.contains("SendWrapper {"));
            assert!(info.contains("data: \"<invalid>\","));
            assert!(info.contains("thread_id: ThreadId("));
            // Hand the wrapper back so that it is dropped on the creator thread.
            w
        });
        assert!(t.join().is_ok());
    }

    /// A clone made on the creator thread formats identically to the original.
    #[test]
    fn test_clone() {
        let w1 = SendWrapper::new(Rc::new(42));
        let w2 = w1.clone();
        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
    }

    /// Cloning on a thread other than the creator panics.
    #[test]
    fn test_clone_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.clone();
        });
        assert!(t.join().is_err());
    }
}