assert_unmoved/
assert_unmoved.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use core::{
4    future::Future,
5    ops,
6    panic::Location,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use std::thread;
11
12use pin_project_lite::pin_project;
13
14pin_project! {
15    /// A type that asserts that the underlying type is not moved after being pinned
16    /// and mutably accessed.
17    ///
18    /// See crate level documentation for details.
19    #[project(!Unpin)]
20    #[derive(Debug)]
21    pub struct AssertUnmoved<T> {
22        #[pin]
23        inner: T,
24        this_addr: usize,
25        first_pinned_mutably_accessed_at: Option<&'static Location<'static>>,
26    }
27    impl<T> PinnedDrop for AssertUnmoved<T> {
28        /// # Panics
29        ///
30        /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
31        fn drop(this: Pin<&mut Self>) {
32            // If the thread is panicking then we can't panic again as that will
33            // cause the process to be aborted.
34            if !thread::panicking() && this.this_addr != 0 {
35                let cur_this = this.addr();
36                assert_eq!(
37                    this.this_addr,
38                    cur_this,
39                    "AssertUnmoved moved before drop\n\
40                     \tfirst pinned mutably accessed at {}\n",
41                    this.first_pinned_mutably_accessed_at.unwrap()
42                );
43            }
44        }
45    }
46}
47
48impl<T> AssertUnmoved<T> {
49    /// Creates a new `AssertUnmoved`.
50    #[must_use]
51    pub const fn new(inner: T) -> Self {
52        Self { inner, this_addr: 0, first_pinned_mutably_accessed_at: None }
53    }
54
55    /// Gets a reference to the underlying type.
56    ///
57    /// Unlike [`get_mut`](AssertUnmoved::get_mut) method, this method can always called.
58    ///
59    /// You can also access the underlying type via [`Deref`](std::ops::Deref) impl.
60    #[must_use]
61    pub const fn get_ref(&self) -> &T {
62        &self.inner
63    }
64
65    /// Gets a mutable reference to the underlying type.
66    ///
67    /// Note that this method can only be called before pinned since
68    /// `AssertUnmoved` is `!Unpin` (this is guaranteed by the type system!).
69    ///
70    /// # Panics
71    ///
72    /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
73    #[track_caller]
74    #[must_use]
75    pub fn get_mut(&mut self) -> &mut T {
76        if self.this_addr != 0 {
77            let cur_this = self.addr();
78            assert_eq!(
79                self.this_addr,
80                cur_this,
81                "AssertUnmoved moved after get_pin_mut call\n\
82                 \tfirst pinned mutably accessed at {}\n",
83                self.first_pinned_mutably_accessed_at.unwrap()
84            );
85        }
86        &mut self.inner
87    }
88
89    /// Gets a pinned mutable reference to the underlying type.
90    ///
91    /// # Panics
92    ///
93    /// Panics if this `AssertUnmoved` moved after being pinned and mutably accessed.
94    ///
95    /// # Examples
96    ///
97    /// Implement own [`Stream`] trait for `AssertUnmoved`.
98    ///
99    /// ```
100    /// use std::{
101    ///     pin::Pin,
102    ///     task::{Context, Poll},
103    /// };
104    ///
105    /// use assert_unmoved::AssertUnmoved;
106    ///
107    /// pub trait MyStream {
108    ///     type Item;
109    ///
110    ///     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
111    /// }
112    ///
113    /// impl<S: MyStream> MyStream for AssertUnmoved<S> {
114    ///     type Item = S::Item;
115    ///
116    ///     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
117    ///         self.get_pin_mut().poll_next(cx)
118    ///     }
119    /// }
120    /// ```
121    ///
122    /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
123    #[track_caller]
124    #[must_use]
125    pub fn get_pin_mut(mut self: Pin<&mut Self>) -> Pin<&mut T> {
126        let cur_this = self.addr();
127        if self.this_addr == 0 {
128            // First time being pinned and mutably accessed.
129            *self.as_mut().project().this_addr = cur_this;
130            *self.as_mut().project().first_pinned_mutably_accessed_at = Some(Location::caller());
131        } else {
132            assert_eq!(
133                self.this_addr,
134                cur_this,
135                "AssertUnmoved moved between get_pin_mut calls\n\
136                 \tfirst pinned mutably accessed at {}\n",
137                self.first_pinned_mutably_accessed_at.unwrap()
138            );
139        }
140        self.project().inner
141    }
142
143    fn addr(&self) -> usize {
144        self as *const Self as usize
145    }
146}
147
148impl<T> ops::Deref for AssertUnmoved<T> {
149    type Target = T;
150
151    fn deref(&self) -> &Self::Target {
152        self.get_ref()
153    }
154}
155
156impl<T> From<T> for AssertUnmoved<T> {
157    /// Converts a `T` into a `AssertUnmoved<T>`.
158    ///
159    /// This is equivalent to [`AssertUnmoved::new`].
160    fn from(inner: T) -> Self {
161        Self::new(inner)
162    }
163}
164
165impl<T: Default> Default for AssertUnmoved<T> {
166    /// Creates a new `AssertUnmoved`, with the default value for `T`.
167    ///
168    /// This is equivalent to [`AssertUnmoved::new(T::default())`](AssertUnmoved::new).
169    fn default() -> Self {
170        Self::new(T::default())
171    }
172}
173
174impl<F: Future> Future for AssertUnmoved<F> {
175    type Output = F::Output;
176
177    #[track_caller]
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        self.get_pin_mut().poll(cx)
180    }
181}
182
#[cfg(feature = "futures03")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures03")))]
mod futures03 {
    // Trait impls for the futures 0.3 ecosystem. Methods taking `Pin<&mut Self>`
    // go through `get_pin_mut` (which performs the move check) and then call the
    // wrapped value's implementation of the same trait; `&self` methods go
    // through `get_ref` and perform no check.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use futures_core::{
        future::FusedFuture,
        stream::{FusedStream, Stream},
    };
    use futures_io as io;
    use futures_sink::Sink;

    use super::AssertUnmoved;

    impl<F: FusedFuture> FusedFuture for AssertUnmoved<F> {
        fn is_terminated(&self) -> bool {
            F::is_terminated(self.get_ref())
        }
    }

    impl<S: Stream> Stream for AssertUnmoved<S> {
        type Item = S::Item;

        #[track_caller]
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            S::poll_next(self.get_pin_mut(), cx)
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            S::size_hint(self.get_ref())
        }
    }

    impl<S: FusedStream> FusedStream for AssertUnmoved<S> {
        fn is_terminated(&self) -> bool {
            S::is_terminated(self.get_ref())
        }
    }

    impl<S: Sink<Item>, Item> Sink<Item> for AssertUnmoved<S> {
        type Error = S::Error;

        #[track_caller]
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_ready(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
            <S as Sink<Item>>::start_send(self.get_pin_mut(), item)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            <S as Sink<Item>>::poll_close(self.get_pin_mut(), cx)
        }
    }

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            R::poll_read(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_read_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &mut [io::IoSliceMut<'_>],
        ) -> Poll<io::Result<usize>> {
            R::poll_read_vectored(self.get_pin_mut(), cx, bufs)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            W::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            W::poll_write_vectored(self.get_pin_mut(), cx, bufs)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_close(self.get_pin_mut(), cx)
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn poll_seek(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            pos: io::SeekFrom,
        ) -> Poll<io::Result<u64>> {
            S::poll_seek(self.get_pin_mut(), cx, pos)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            R::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            R::consume(self.get_pin_mut(), amt);
        }
    }
}
322
#[cfg(feature = "tokio02")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio02")))]
mod tokio02 {
    // Trait impls for tokio 0.2 I/O traits. Pinned methods call `get_pin_mut`
    // (performing the move check) and then delegate to the wrapped value's
    // implementation; the `&self` method delegates via `get_ref` with no check.

    use core::{
        mem::MaybeUninit,
        pin::Pin,
        task::{Context, Poll},
    };
    use std::io;

    use bytes05::{Buf, BufMut};
    use tokio02_crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite};

    use super::AssertUnmoved;

    impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> {
        unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
            // SAFETY: this is a pure pass-through; the caller's obligations for
            // `AssertUnmoved<R>` are exactly those of `R`'s implementation.
            unsafe { R::prepare_uninitialized_buffer(self.get_ref(), buf) }
        }

        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            R::poll_read(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_read_buf<B: BufMut>(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut B,
        ) -> Poll<io::Result<usize>>
        where
            Self: Sized,
        {
            R::poll_read_buf(self.get_pin_mut(), cx, buf)
        }
    }

    impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            W::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_shutdown(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_write_buf<B: Buf>(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut B,
        ) -> Poll<Result<usize, io::Error>>
        where
            Self: Sized,
        {
            W::poll_write_buf(self.get_pin_mut(), cx, buf)
        }
    }

    impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            pos: io::SeekFrom,
        ) -> Poll<io::Result<()>> {
            S::start_seek(self.get_pin_mut(), cx, pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            S::poll_complete(self.get_pin_mut(), cx)
        }
    }

    impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            R::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            R::consume(self.get_pin_mut(), amt);
        }
    }
}
427
#[cfg(feature = "tokio03")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio03")))]
mod tokio03 {
    // Trait impls for tokio 0.3 I/O traits. Each method calls `get_pin_mut`
    // (performing the move check) and delegates to the wrapped value's
    // implementation of the same trait.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio03_crate::io;

    use super::AssertUnmoved;

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            R::poll_read(self.get_pin_mut(), cx, buf)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            W::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_shutdown(self.get_pin_mut(), cx)
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
            S::start_seek(self.get_pin_mut(), pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            S::poll_complete(self.get_pin_mut(), cx)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            R::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            R::consume(self.get_pin_mut(), amt);
        }
    }
}
496
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
mod tokio1 {
    // Trait impls for tokio 1.x I/O traits. Pinned methods call `get_pin_mut`
    // (performing the move check) and delegate to the wrapped value's
    // implementation; `is_write_vectored` delegates via `get_ref` with no check.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio1_crate::io;

    use super::AssertUnmoved;

    impl<R: io::AsyncRead> io::AsyncRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            R::poll_read(self.get_pin_mut(), cx, buf)
        }
    }

    impl<W: io::AsyncWrite> io::AsyncWrite for AssertUnmoved<W> {
        #[track_caller]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            W::poll_write(self.get_pin_mut(), cx, buf)
        }

        #[track_caller]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_flush(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            W::poll_shutdown(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            W::poll_write_vectored(self.get_pin_mut(), cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            W::is_write_vectored(self.get_ref())
        }
    }

    impl<S: io::AsyncSeek> io::AsyncSeek for AssertUnmoved<S> {
        #[track_caller]
        fn start_seek(self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
            S::start_seek(self.get_pin_mut(), pos)
        }

        #[track_caller]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            S::poll_complete(self.get_pin_mut(), cx)
        }
    }

    impl<R: io::AsyncBufRead> io::AsyncBufRead for AssertUnmoved<R> {
        #[track_caller]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            R::poll_fill_buf(self.get_pin_mut(), cx)
        }

        #[track_caller]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            R::consume(self.get_pin_mut(), amt);
        }
    }
}