compio_send_wrapper/lib.rs
1// Copyright 2017 Thomas Keh.
2// Copyright 2024 compio-rs
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! This [Rust] library implements a wrapper type called [`SendWrapper`] which
11//! allows you to move around non-[`Send`] types between threads, as long as you
12//! access the contained value only from within the original thread. You also
13//! have to make sure that the wrapper is dropped from within the original
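//! thread. If these constraints are violated, operations such as
//! [`SendWrapper::take`] and dropping the wrapper panic, while the checked
//! accessors such as [`SendWrapper::get`] return `None`.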
15//! [`SendWrapper<T>`] implements [`Send`] and [`Sync`] for any type `T`.
16//!
17//! # Examples
18//!
19//! ```rust
20//! use std::{rc::Rc, sync::mpsc::channel, thread};
21//!
22//! use compio_send_wrapper::SendWrapper;
23//!
24//! // Rc is a non-Send type.
25//! let value = Rc::new(42);
26//!
27//! // We now wrap the value with `SendWrapper` (value is moved inside).
28//! let wrapped_value = SendWrapper::new(value);
29//!
30//! // A channel allows us to move the wrapped value between threads.
31//! let (sender, receiver) = channel();
32//!
33//! let t = thread::spawn(move || {
//! // This would panic, because `get` returns `None` outside the creator thread:
35//! // let value = wrapped_value.get().unwrap();
36//!
37//! // Move SendWrapper back to main thread, so it can be dropped from there.
38//! // If you leave this out the thread will panic because of dropping from wrong thread.
39//! sender.send(wrapped_value).unwrap();
40//! });
41//!
42//! let wrapped_value = receiver.recv().unwrap();
43//!
44//! // Now you can use the value again.
45//! let value = wrapped_value.get().unwrap();
46//!
47//! let mut wrapped_value = wrapped_value;
48//!
49//! // You can also get a mutable reference to the value.
50//! let value = wrapped_value.get_mut().unwrap();
51//! ```
52//!
53//! # Features
54//!
55//! This crate exposes several optional features:
56//!
57//! - `futures`: Enables [`Future`] and [`Stream`] implementations for
58//! [`SendWrapper`].
59//! - `current_thread_id`: Uses the unstable [`std::thread::current_id`] API (on
60//! nightly Rust) to track the originating thread more efficiently.
61//! - `nightly`: Enables nightly-only, experimental functionality used by this
62//! crate (including support for `current_thread_id` as configured in
63//! `Cargo.toml`).
64//!
65//! You can enable them in `Cargo.toml` like so:
66//!
67//! ```toml
68//! compio-send-wrapper = { version = "...", features = ["futures"] }
69//! # or, for example:
70//! # compio-send-wrapper = { version = "...", features = ["futures", "current_thread_id"] }
71//! ```
72//!
73//! # License
74//!
75//! `compio-send-wrapper` is distributed under the terms of both the MIT license
76//! and the Apache License (Version 2.0).
77//!
//! See LICENSE-APACHE.txt and LICENSE-MIT.txt for details.
79//!
80//! [Rust]: https://www.rust-lang.org
81//! [`Future`]: std::future::Future
82//! [`Stream`]: futures_core::Stream
83// To build docs locally use `RUSTDOCFLAGS="--cfg docsrs" cargo doc --open --all-features`
84#![cfg_attr(docsrs, feature(doc_cfg))]
85#![cfg_attr(
86 all(not(loom), feature = "current_thread_id"),
87 feature(current_thread_id)
88)]
89#![warn(missing_docs)]
90
91#[cfg(feature = "futures")]
92#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
93mod futures;
94
95use std::{
96 fmt,
97 mem::{self, ManuallyDrop},
98 pin::Pin,
99};
100
101cfg_if::cfg_if! {
102 if #[cfg(any(loom, not(feature = "current_thread_id")))] {
103 #[cfg(loom)]
104 use loom::{thread_local, cell::Cell, thread::{self, ThreadId}};
105 #[cfg(not(loom))]
106 use std::{thread_local, cell::Cell, thread::{self, ThreadId}};
107
108 thread_local! {
109 static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
110 }
111
112 /// Get the current [`ThreadId`].
113 pub(crate) fn current_id() -> ThreadId {
114 THREAD_ID.with(|id| id.get())
115 }
116 } else {
117 use std::thread::{self, current_id, ThreadId};
118 }
119}
120
/// Wrapper that makes a value of any type `T` movable between threads.
///
/// The wrapped value may only be accessed on, and must only be dropped on,
/// the thread that created the wrapper. The accessors check this at runtime
/// by comparing the current thread with the recorded creator thread.
pub struct SendWrapper<T> {
    // The wrapped value. `ManuallyDrop` stops the compiler from dropping it
    // automatically, so `Drop for SendWrapper` can check the thread first.
    data: ManuallyDrop<T>,
    // The thread that created this wrapper; the only one allowed to touch `data`.
    thread_id: ThreadId,
}
129
130impl<T> SendWrapper<T> {
131 /// Create a `SendWrapper<T>` wrapper around a value of type `T`.
132 /// The wrapper takes ownership of the value.
133 #[inline]
134 pub fn new(data: T) -> SendWrapper<T> {
135 SendWrapper {
136 data: ManuallyDrop::new(data),
137 thread_id: current_id(),
138 }
139 }
140
141 /// Returns `true` if the value can be safely accessed from within the
142 /// current thread.
143 #[inline]
144 pub fn valid(&self) -> bool {
145 self.thread_id == current_id()
146 }
147
148 /// Takes the value out of the `SendWrapper<T>`.
149 ///
150 /// # Safety
151 ///
152 /// The caller should be in the same thread as the creator.
153 pub unsafe fn take_unchecked(self) -> T {
154 // Prevent drop() from being called, as it would drop `self.data` twice
155 let mut this = ManuallyDrop::new(self);
156
157 // Safety:
158 // - The caller of this unsafe function guarantees that it's valid to access `T`
159 // from the current thread (the safe `take` method enforces this precondition
160 // before calling `take_unchecked`).
161 // - We only move out from `self.data` here and in drop, so `self.data` is
162 // present
163 unsafe { ManuallyDrop::take(&mut this.data) }
164 }
165
166 /// Takes the value out of the `SendWrapper<T>`.
167 ///
168 /// # Panics
169 ///
170 /// Panics if it is called from a different thread than the one the
171 /// `SendWrapper<T>` instance has been created with.
172 #[track_caller]
173 pub fn take(self) -> T {
174 if self.valid() {
175 // SAFETY: the same thread as the creator
176 unsafe { self.take_unchecked() }
177 } else {
178 invalid_deref()
179 }
180 }
181
182 /// Returns a reference to the contained value.
183 ///
184 /// # Safety
185 ///
186 /// The caller should be in the same thread as the creator.
187 #[inline]
188 pub unsafe fn get_unchecked(&self) -> &T {
189 &self.data
190 }
191
192 /// Returns a mutable reference to the contained value.
193 ///
194 /// # Safety
195 ///
196 /// The caller should be in the same thread as the creator.
197 #[inline]
198 pub unsafe fn get_unchecked_mut(&mut self) -> &mut T {
199 &mut self.data
200 }
201
202 /// Returns a pinned reference to the contained value.
203 ///
204 /// # Safety
205 ///
206 /// The caller should be in the same thread as the creator.
207 #[inline]
208 pub unsafe fn get_unchecked_pinned(self: Pin<&Self>) -> Pin<&T> {
209 // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
210 unsafe { self.map_unchecked(|s| &*s.data) }
211 }
212
213 /// Returns a pinned mutable reference to the contained value.
214 ///
215 /// # Safety
216 ///
217 /// The caller should be in the same thread as the creator.
218 #[inline]
219 pub unsafe fn get_unchecked_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
220 // SAFETY: as long as `SendWrapper` is pinned, the inner data is pinned too.
221 unsafe { self.map_unchecked_mut(|s| &mut *s.data) }
222 }
223
224 /// Returns a reference to the contained value, if valid.
225 #[inline]
226 pub fn get(&self) -> Option<&T> {
227 if self.valid() { Some(&self.data) } else { None }
228 }
229
230 /// Returns a mutable reference to the contained value, if valid.
231 #[inline]
232 pub fn get_mut(&mut self) -> Option<&mut T> {
233 if self.valid() {
234 Some(&mut self.data)
235 } else {
236 None
237 }
238 }
239
240 /// Returns a pinned reference to the contained value, if valid.
241 #[inline]
242 pub fn get_pinned(self: Pin<&Self>) -> Option<Pin<&T>> {
243 if self.valid() {
244 // SAFETY: the same thread as the creator
245 Some(unsafe { self.get_unchecked_pinned() })
246 } else {
247 None
248 }
249 }
250
251 /// Returns a pinned mutable reference to the contained value, if valid.
252 #[inline]
253 pub fn get_pinned_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
254 if self.valid() {
255 // SAFETY: the same thread as the creator
256 Some(unsafe { self.get_unchecked_pinned_mut() })
257 } else {
258 None
259 }
260 }
261
262 /// Returns a tracker that can be used to check if the current thread is
263 /// the same as the creator thread.
264 #[inline]
265 pub fn tracker(&self) -> SendWrapper<()> {
266 SendWrapper {
267 data: ManuallyDrop::new(()),
268 thread_id: self.thread_id,
269 }
270 }
271}
272
273unsafe impl<T> Send for SendWrapper<T> {}
274unsafe impl<T> Sync for SendWrapper<T> {}
275
276impl<T> Drop for SendWrapper<T> {
277 /// Drops the contained value.
278 ///
279 /// # Panics
280 ///
281 /// Dropping panics if it is done from a different thread than the one the
282 /// `SendWrapper<T>` instance has been created with.
283 ///
284 /// Exceptions:
285 /// - There is no extra panic if the thread is already panicking/unwinding.
286 /// This is because otherwise there would be double panics (usually
287 /// resulting in an abort) when dereferencing from a wrong thread.
288 /// - If `T` has a trivial drop ([`needs_drop::<T>()`] is false) then this
289 /// method never panics.
290 ///
291 /// [`needs_drop::<T>()`]: std::mem::needs_drop
292 #[track_caller]
293 fn drop(&mut self) {
294 // If the drop is trivial (`needs_drop` = false), then dropping `T` can't access
295 // it and so it can be safely dropped on any thread.
296 if !mem::needs_drop::<T>() || self.valid() {
297 unsafe {
298 // Drop the inner value
299 //
300 // SAFETY:
301 // - We've just checked that it's valid to drop `T` on this thread
302 // - We only move out from `self.data` here and in drop, so `self.data` is
303 // present
304 ManuallyDrop::drop(&mut self.data);
305 }
306 } else {
307 invalid_drop()
308 }
309 }
310}
311
312impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
313 /// Formats the value using the given formatter.
314 ///
315 /// If the `SendWrapper<T>` is formatted from a different thread than the
316 /// one it was created on, the `data` field is shown as `"<invalid>"`
317 /// instead of causing a panic.
318 #[track_caller]
319 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320 let mut f = f.debug_struct("SendWrapper");
321 if let Some(data) = self.get() {
322 f.field("data", data);
323 } else {
324 f.field("data", &"<invalid>");
325 }
326 f.field("thread_id", &self.thread_id).finish()
327 }
328}
329
330impl<T: Clone> Clone for SendWrapper<T> {
331 /// Returns a copy of the value.
332 ///
333 /// # Panics
334 ///
335 /// Cloning panics if it is done from a different thread than the one
336 /// the `SendWrapper<T>` instance has been created with.
337 #[track_caller]
338 fn clone(&self) -> Self {
339 Self::new(self.get().unwrap_or_else(|| invalid_deref()).clone())
340 }
341}
342
/// Panics because a `SendWrapper<T>` was accessed off its creator thread.
///
/// Kept out of line and marked cold so the accessors' fast path stays small;
/// `#[track_caller]` reports the panic at the (tracked) call site.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_deref() -> ! {
    const MESSAGE: &str = concat!(
        "Accessed SendWrapper<T> variable from a thread different to the ",
        "one it has been created with."
    );
    panic!("{}", MESSAGE)
}
352
/// Panics because a `SendWrapper<T>` was polled off its creator thread.
///
/// Only compiled with the `futures` feature; cold and out of line like the
/// other panic helpers.
#[cold]
#[inline(never)]
#[track_caller]
#[cfg(feature = "futures")]
fn invalid_poll() -> ! {
    const MESSAGE: &str = concat!(
        "Polling SendWrapper<T> variable from a thread different to the one ",
        "it has been created with."
    );
    panic!("{}", MESSAGE)
}
363
/// Reports that a `SendWrapper<T>` was dropped off its creator thread.
///
/// Panics unless the thread is already unwinding, in which case it returns
/// silently.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_drop() {
    const MESSAGE: &str = concat!(
        "Dropped SendWrapper<T> variable from a thread different to the one ",
        "it has been created with."
    );

    // A second panic during unwinding (which may itself have been caused by
    // an access from the wrong thread) would abort, so stay quiet then.
    if thread::panicking() {
        return;
    }
    panic!("{}", MESSAGE)
}
377
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        rc::Rc,
        sync::{Arc, mpsc::channel},
        thread,
    };

    use super::SendWrapper;

    #[test]
    fn get_and_get_mut_on_creator_thread_and_pinned_variants() {
        let mut wrapper = SendWrapper::new(1_i32);

        // On the creator thread, the plain accessors should return Some.
        let r = wrapper.get();
        assert!(r.is_some());
        assert_eq!(*r.unwrap(), 1);

        let r_mut = wrapper.get_mut();
        assert!(r_mut.is_some());
        *r_mut.unwrap() = 2;

        // The change via get_mut should be visible via get as well.
        let r_after = wrapper.get();
        assert!(r_after.is_some());
        assert_eq!(*r_after.unwrap(), 2);

        // Pinned shared reference should also succeed on the creator thread.
        let pinned = Pin::new(&wrapper);
        let pinned_ref = pinned.get_pinned();
        assert!(pinned_ref.is_some());
        assert_eq!(*pinned_ref.unwrap(), 2);

        // Pinned mutable reference should succeed and allow mutation.
        let mut wrapper2 = SendWrapper::new(10_i32);
        let pinned_mut = Pin::new(&mut wrapper2);
        let pinned_mut_ref = pinned_mut.get_pinned_mut();
        assert!(pinned_mut_ref.is_some());
        *pinned_mut_ref.unwrap() = 11;

        let after_mut = wrapper2.get();
        assert!(after_mut.is_some());
        assert_eq!(*after_mut.unwrap(), 11);
    }

    #[test]
    fn accessors_return_none_on_non_creator_thread() {
        let mut wrapper = SendWrapper::new(123_i32);

        // Move the wrapper to another thread; that thread is not the creator.
        let handle = thread::spawn(move || {
            // Plain accessors should return None on non-creator thread.
            assert!(wrapper.get().is_none());
            assert!(wrapper.get_mut().is_none());

            // Pinned accessors should also return None on non-creator thread.
            let pinned = Pin::new(&wrapper);
            assert!(pinned.get_pinned().is_none());

            let mut wrapper = wrapper;
            let pinned_mut = Pin::new(&mut wrapper);
            assert!(pinned_mut.get_pinned_mut().is_none());
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_valid() {
        let (sender, receiver) = channel();
        let w = SendWrapper::new(Rc::new(42));
        assert!(w.valid());
        let t = thread::spawn(move || {
            // move SendWrapper back to main thread, so it can be dropped from there
            sender.send(w).unwrap();
        });
        let w2 = receiver.recv().unwrap();
        assert!(w2.valid());
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            assert!(!w.valid());
            w
        });
        let join_result = t.join();
        assert!(join_result.is_ok());
    }

    #[test]
    fn test_drop_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            drop(w);
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    #[test]
    fn test_take() {
        let w = SendWrapper::new(Rc::new(42));
        let inner: Rc<usize> = w.take();
        assert_eq!(42, *inner);
    }

    #[test]
    fn test_take_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.take();
        });
        assert!(t.join().is_err());
    }

    #[test]
    fn test_sync() {
        // `Arc<X>` is only `Send` when `X: Send + Sync`, so moving an
        // `Arc<SendWrapper<Rc<_>>>` into another thread compiles only because
        // `SendWrapper<T>` is `Sync` even for a non-`Sync` `T` such as `Rc`.
        //
        // The clone must actually be *used* in the closure: since Rust 2021 a
        // bare `let _ = arc;` does not capture `arc`, which would make this
        // test vacuous.
        let arc = Arc::new(SendWrapper::new(Rc::new(42)));
        let remote = Arc::clone(&arc);
        let t = thread::spawn(move || {
            // Shared access from a foreign thread is refused rather than panicking.
            assert!(remote.get().is_none());
            // Dropping `remote` here only decrements the count; `arc` keeps the
            // wrapper alive, so the wrapper itself is dropped on this test's thread.
        });
        assert!(t.join().is_ok());
        assert_eq!(**arc.get().unwrap(), 42);
    }

    #[test]
    fn test_debug() {
        let w = SendWrapper::new(Rc::new(42));
        let info = format!("{:?}", w);
        assert!(info.contains("SendWrapper {"));
        assert!(info.contains("data: 42,"));
        assert!(info.contains("thread_id: ThreadId("));
    }

    #[test]
    fn test_debug_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let info = format!("{:?}", w);
            assert!(info.contains("SendWrapper {"));
            assert!(info.contains("data: \"<invalid>\","));
            assert!(info.contains("thread_id: ThreadId("));
            w
        });
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_clone() {
        let w1 = SendWrapper::new(Rc::new(42));
        let w2 = w1.clone();
        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
    }

    #[test]
    fn test_clone_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.clone();
        });
        assert!(t.join().is_err());
    }
}