Skip to main content

assert_unmoved/
assert_unmoved.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use core::{
4    future::Future,
5    ops,
6    panic::Location,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use std::thread;
11
12use pin_project_lite::pin_project;
13
14pin_project! {
15    /// A type that asserts that the underlying type is not moved after being pinned
16    /// and mutably accessed.
17    ///
18    /// See the [crate-level documentation](crate) for details.
19    #[project(!Unpin)]
20    #[derive(Debug)]
21    pub struct AssertUnmoved<T> {
22        #[pin]
23        inner: T,
24        this_addr: usize,
25        first_pinned_mutably_accessed_at: Option<&'static Location<'static>>,
26    }
27    impl<T> PinnedDrop for AssertUnmoved<T> {
28        /// # Panics
29        ///
30        /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
31        fn drop(this: Pin<&mut Self>) {
32            // If the thread is panicking then we can't panic again as that will
33            // cause the process to be aborted.
34            if !thread::panicking() && this.this_addr != 0 {
35                let cur_this = this.addr();
36                assert_eq!(
37                    this.this_addr,
38                    cur_this,
39                    "AssertUnmoved moved before drop\n\
40                     \tfirst pinned mutably accessed at {}\n",
41                    this.first_pinned_mutably_accessed_at.unwrap()
42                );
43            }
44        }
45    }
46}
47
48impl<T> AssertUnmoved<T> {
49    /// Creates a new `AssertUnmoved`.
50    #[must_use]
51    pub const fn new(inner: T) -> Self {
52        Self { inner, this_addr: 0, first_pinned_mutably_accessed_at: None }
53    }
54
55    /// Gets a reference to the underlying type.
56    ///
57    /// Unlike [`get_mut`](AssertUnmoved::get_mut) method, this method can always called.
58    ///
59    /// You can also access the underlying type via [`Deref`](std::ops::Deref) impl.
60    #[must_use]
61    pub const fn get_ref(&self) -> &T {
62        &self.inner
63    }
64
65    /// Gets a mutable reference to the underlying type.
66    ///
67    /// Note that this method can only be called before pinned since
68    /// `AssertUnmoved` is `!Unpin` (this is guaranteed by the type system!).
69    ///
70    /// # Panics
71    ///
72    /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
73    #[must_use]
74    #[track_caller]
75    pub fn get_mut(&mut self) -> &mut T {
76        if self.this_addr != 0 {
77            let cur_this = self.addr();
78            assert_eq!(
79                self.this_addr,
80                cur_this,
81                "AssertUnmoved moved after get_pin_mut call\n\
82                 \tfirst pinned mutably accessed at {}\n",
83                self.first_pinned_mutably_accessed_at.unwrap()
84            );
85        }
86        &mut self.inner
87    }
88
89    /// Gets a pinned mutable reference to the underlying type.
90    ///
91    /// # Panics
92    ///
93    /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
94    ///
95    /// # Examples
96    ///
97    /// Implement own [`Stream`] trait for `AssertUnmoved`.
98    ///
99    /// ```
100    /// use std::{
101    ///     pin::Pin,
102    ///     task::{Context, Poll},
103    /// };
104    ///
105    /// use assert_unmoved::AssertUnmoved;
106    ///
107    /// trait MyStream {
108    ///     type Item;
109    ///
110    ///     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
111    /// }
112    ///
113    /// impl<S: MyStream> MyStream for AssertUnmoved<S> {
114    ///     type Item = S::Item;
115    ///
116    ///     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
117    ///         self.get_pin_mut().poll_next(cx)
118    ///     }
119    /// }
120    /// ```
121    ///
122    /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
123    #[must_use]
124    #[track_caller]
125    pub fn get_pin_mut(mut self: Pin<&mut Self>) -> Pin<&mut T> {
126        let cur_this = self.addr();
127        if self.this_addr == 0 {
128            // First time being pinned and mutably accessed.
129            *self.as_mut().project().this_addr = cur_this;
130            *self.as_mut().project().first_pinned_mutably_accessed_at = Some(Location::caller());
131        } else {
132            assert_eq!(
133                self.this_addr,
134                cur_this,
135                "AssertUnmoved moved between get_pin_mut calls\n\
136                 \tfirst pinned mutably accessed at {}\n",
137                self.first_pinned_mutably_accessed_at.unwrap()
138            );
139        }
140        self.project().inner
141    }
142
143    fn addr(&self) -> usize {
144        self as *const Self as usize
145    }
146}
147
148impl<T> ops::Deref for AssertUnmoved<T> {
149    type Target = T;
150
151    fn deref(&self) -> &Self::Target {
152        self.get_ref()
153    }
154}
155
156impl<T> From<T> for AssertUnmoved<T> {
157    /// Converts a `T` into a `AssertUnmoved<T>`.
158    ///
159    /// This is equivalent to [`AssertUnmoved::new`].
160    fn from(inner: T) -> Self {
161        Self::new(inner)
162    }
163}
164
165impl<T: Default> Default for AssertUnmoved<T> {
166    /// Creates a new `AssertUnmoved`, with the default value for `T`.
167    ///
168    /// This is equivalent to [`AssertUnmoved::new(T::default())`](AssertUnmoved::new).
169    fn default() -> Self {
170        Self::new(T::default())
171    }
172}
173
174impl<F: Future> Future for AssertUnmoved<F> {
175    type Output = F::Output;
176
177    #[track_caller]
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        self.get_pin_mut().poll(cx)
180    }
181}
182
#[cfg(feature = "futures03")]
mod futures03 {
    // Forwarding impls for the `futures` 0.3 traits. Methods taking
    // `Pin<&mut Self>` reach the inner value through `get_pin_mut`, so each
    // call runs the move check. `&self` methods use `get_ref`. Trait methods
    // are called in fully-qualified form so the target trait is explicit.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use futures_core::{
        future::FusedFuture,
        stream::{FusedStream, Stream},
    };
    use futures_io as io;
    use futures_sink::Sink;

    use super::AssertUnmoved;

    impl<F: FusedFuture> FusedFuture for AssertUnmoved<F> {
        fn is_terminated(&self) -> bool {
            FusedFuture::is_terminated(self.get_ref())
        }
    }

    impl<S: Stream> Stream for AssertUnmoved<S> {
        type Item = S::Item;

        #[track_caller]
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Stream::poll_next(self.get_pin_mut(), cx)
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            Stream::size_hint(self.get_ref())
        }
    }

    impl<S: FusedStream> FusedStream for AssertUnmoved<S> {
        fn is_terminated(&self) -> bool {
            FusedStream::is_terminated(self.get_ref())
        }
    }

    impl<S: Sink<Item>, Item> Sink<Item> for AssertUnmoved<S> {
        type Error = S::Error;

        #[track_caller]
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_ready(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
            <S as Sink<Item>>::start_send(self.get_pin_mut(), item)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_close(self.get_pin_mut(), cx)
        }
    }

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            io::AsyncRead::poll_read(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_read_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &mut [io::IoSliceMut<'_>],
        ) -> Poll<io::Result<usize>> {
            io::AsyncRead::poll_read_vectored(self.get_pin_mut(), cx, bufs)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            io::AsyncWrite::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            io::AsyncWrite::poll_write_vectored(self.get_pin_mut(), cx, bufs)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            io::AsyncWrite::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            io::AsyncWrite::poll_close(self.get_pin_mut(), cx)
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn poll_seek(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            pos: io::SeekFrom,
        ) -> Poll<io::Result<u64>> {
            io::AsyncSeek::poll_seek(self.get_pin_mut(), cx, pos)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            io::AsyncBufRead::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            io::AsyncBufRead::consume(self.get_pin_mut(), amt);
        }
    }
}
321
#[cfg(feature = "tokio02")]
mod tokio02 {
    // Forwarding impls for the tokio 0.2 I/O traits. Methods taking
    // `Pin<&mut Self>` go through `get_pin_mut` (move-checked); `&self` methods
    // go through `get_ref`. Trait methods are called in fully-qualified form.

    use core::{
        mem::MaybeUninit,
        pin::Pin,
        task::{Context, Poll},
    };
    use std::io;

    use bytes05::{Buf, BufMut};
    use tokio02_crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite};

    use super::AssertUnmoved;

    impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> {
        unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
            // SAFETY: the caller upholds this method's contract for `self`; the
            // same buffer is passed unchanged to the inner reader, so the
            // contract carries over to that call.
            unsafe { AsyncRead::prepare_uninitialized_buffer(self.get_ref(), buf) }
        }

        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            AsyncRead::poll_read(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_read_buf<B: BufMut>(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut B,
        ) -> Poll<io::Result<usize>>
        where
            Self: Sized,
        {
            AsyncRead::poll_read_buf(self.get_pin_mut(), cx, buf)
        }
    }

    impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            AsyncWrite::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            AsyncWrite::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            AsyncWrite::poll_shutdown(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_write_buf<B: Buf>(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut B,
        ) -> Poll<Result<usize, io::Error>>
        where
            Self: Sized,
        {
            AsyncWrite::poll_write_buf(self.get_pin_mut(), cx, buf)
        }
    }

    impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            pos: io::SeekFrom,
        ) -> Poll<io::Result<()>> {
            AsyncSeek::start_seek(self.get_pin_mut(), cx, pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            AsyncSeek::poll_complete(self.get_pin_mut(), cx)
        }
    }

    impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            AsyncBufRead::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            AsyncBufRead::consume(self.get_pin_mut(), amt);
        }
    }
}
425
#[cfg(feature = "tokio03")]
mod tokio03 {
    // Forwarding impls for the tokio 0.3 I/O traits. Methods taking
    // `Pin<&mut Self>` go through `get_pin_mut` (move-checked); `&self` methods
    // go through `get_ref`.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio03_crate::io;

    use super::AssertUnmoved;

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.get_pin_mut().poll_read(cx, buf)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.get_pin_mut().poll_write(cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_pin_mut().poll_flush(cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.get_pin_mut().poll_shutdown(cx)
        }

        // Forward the vectored-write methods as the tokio1 impl does. Otherwise
        // the trait's default (non-vectored) implementations would be used in
        // place of the inner writer's, so the wrapper would not be a
        // transparent proxy.
        // NOTE(review): these trait methods were added during the tokio 0.3
        // series (believed 0.3.4); confirm the `tokio03` dependency requires
        // at least that version.
        #[track_caller]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            self.get_pin_mut().poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.get_ref().is_write_vectored()
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
            self.get_pin_mut().start_seek(pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            self.get_pin_mut().poll_complete(cx)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            self.get_pin_mut().poll_fill_buf(cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            self.get_pin_mut().consume(amt);
        }
    }
}
493
#[cfg(feature = "tokio1")]
mod tokio1 {
    // Forwarding impls for the tokio 1.x I/O traits. Methods taking
    // `Pin<&mut Self>` go through `get_pin_mut` (move-checked); `&self` methods
    // go through `get_ref`. Trait methods are called in fully-qualified form.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio1_crate::io;

    use super::AssertUnmoved;

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            io::AsyncRead::poll_read(self.get_pin_mut(), cx, buf)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            io::AsyncWrite::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            io::AsyncWrite::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            io::AsyncWrite::poll_shutdown(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            io::AsyncWrite::poll_write_vectored(self.get_pin_mut(), cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            io::AsyncWrite::is_write_vectored(self.get_ref())
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
            io::AsyncSeek::start_seek(self.get_pin_mut(), pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            io::AsyncSeek::poll_complete(self.get_pin_mut(), cx)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            io::AsyncBufRead::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            io::AsyncBufRead::consume(self.get_pin_mut(), amt);
        }
    }
}
573}