extend_mut/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#![cfg_attr(not(test), no_std)]

/*!

This crate provides a safe way to extend the lifetime of a exclusive reference.

[`extend_mut`] allows for safe extension of the lifetime of a exclusive reference
with a blocking closure.

[`extend_mut_async`] is similar to [`extend_mut`], but it is async and requires
a linear type be safe - but Rust does not have linear types yet, so it is unsafe.

*/

use core::{
    future::Future,
    pin::Pin,
    ptr,
    task::{Context, Poll},
};

// With `panic=abort` it will directly go to panic handler without unwind.
// With `panic=unwind` it will painc-in-drop, which will cause panic_nounwind.
fn abort_no_unwind(msg: &'static str) -> ! {
    struct DoublePanic(&'static str);

    impl Drop for DoublePanic {
        fn drop(&mut self) {
            panic!("{}", self.0);
        }
    }

    let _double_panic = DoublePanic(msg);
    panic!("{msg}");
}

// SAFETY:
//     if `'a` is >= `'b`, then is is safe by [extend_mut_proof_for_smaller] proof.
//     if `f` will diverge, `'a` will be `'static`, which is valid.
//     if `f` will return `&'b mut T` back, then `'a` will be large enough to fit this call.
//         That way, `&'b mut T` will not exist for `'b`, but only for `'a`.
//
//     if `f` stored `&'b mut T`, then
//         if `f` diverged, it is fine, because `'a` becomes `'static`.
//         else `f` must return `&'b mut T` different from the one it stored.
//           we verify it by an assertion.
//     else we know that `f` did not store the reference we gave it, so it is sound.

/// Extends the lifetime of a mutable reference. Note that `f` must return the same reference
/// that was passed to it, otherwise it will abort the process.
pub fn extend_mut<'a, 'b, T: 'b, R>(
    mut_ref: &'a mut T,
    f: impl FnOnce(&'b mut T) -> (&'b mut T, R),
) -> R {
    let ptr = ptr::from_mut(mut_ref);
    let (extended, next) = f(unsafe { &mut *ptr });
    if ptr != ptr::from_mut(extended) {
        abort_no_unwind("ExtendMut: Pointer changed");
    }

    next
}

pin_project_lite::pin_project! {
    /// Future returned by returned by [extend_mut_async].
    /// Consult it's documentation for more information and safety requirements.
    pub struct ExtendMutFuture<'a, T, Fut, R> {
        ptr: *mut T,
        marker: core::marker::PhantomData<(&'a mut T, R)>,
        #[pin]
        future: Fut,
        // Instead of having that bool, we might make `ptr` null.
        ready: bool,
    }

    impl<'a, T, Fut, R> PinnedDrop for ExtendMutFuture<'a, T, Fut, R> {
        fn drop(this: Pin<&mut Self>) {
            if !*this.project().ready {
                abort_no_unwind("Cannot drop ExtendMutFuture before it yields Poll::Ready");
            }
        }
    }
}

impl<'a, T, Fut, R> Future for ExtendMutFuture<'a, T, Fut, R>
where
    Fut: Future<Output = (&'a mut T, R)>,
{
    type Output = R;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let ptr = *this.ptr;

        if *this.ready {
            return Poll::Pending;
        }

        match this.future.poll(cx) {
            Poll::Ready((extended, ret)) => {
                if ptr == ptr::from_mut(extended) {
                    *this.ready = true;
                    Poll::Ready(ret)
                } else {
                    abort_no_unwind("ExtendMut: Pointer changed")
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Async version of [`extend_mut`]. You should not drop the future returned by [`extend_mut_async`]
/// until it yields [`Poll::Ready`] - if you do, it will abort the process. This function is *not*
/// cancel-safe.
///
/// If polled after yielding [`Poll::Ready`], it will always return [`Poll::Pending`].
///
/// # Safety
///
/// Shortly - do not cancel returned future.
///
/// You must not skip abortion on dropping the future returned by [`extend_mut_async`]
/// by any means, including [forget](core::mem::forget), [`ManuallyDrop`](core::mem::ManuallyDrop) etc. Otherwise,
/// borrow checker will allow you to use `mut_ref` while it might be used by `f`, which will
/// be undefined behavior.
pub unsafe fn extend_mut_async<'a, 'b, T: 'b, F, Fut, R>(
    mut_ref: &'a mut T,
    f: F,
) -> ExtendMutFuture<'b, T, Fut, R>
where
    Fut: Future<Output = (&'b mut T, R)>,
    F: FnOnce(&'b mut T) -> Fut,
{
    let ptr = ptr::from_mut(mut_ref);
    let future = f(unsafe { &mut *ptr });

    ExtendMutFuture {
        ptr,
        marker: core::marker::PhantomData,
        future,
        ready: false,
    }
}

#[allow(dead_code)]
fn extend_mut_proof_for_smaller<'a: 'b, 'b, T: 'b, R>(
    mut_ref: &'a mut T,
    f: impl FnOnce(&'b mut T) -> (&'b mut T, R),
) -> R {
    f(mut_ref).1
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_extend_mut() {
        let mut x = 5;

        fn want_static(x: &'static mut i32) -> &'static mut i32 {
            assert_eq!(*x, 5);
            *x += 1;
            *x += 1;
            x
        }

        let r = extend_mut(&mut x, |x| (want_static(x), 6));
        assert_eq!(r, 6);
        assert_eq!(x, 7);
    }

    #[test]
    fn test_extend_mut_async_immediate() {
        use core::pin::pin;
        use core::task::{Context, Poll, Waker};

        let mut x = 5;
        async fn want_static(x: &'static mut i32) -> &'static mut i32 {
            assert_eq!(*x, 5);
            x
        }

        let fut = unsafe { extend_mut_async(&mut x, async |x| (want_static(x).await, 8)) };
        let mut fut = pin!(fut);
        let ret = loop {
            match fut.as_mut().poll(&mut Context::from_waker(&Waker::noop())) {
                Poll::Ready(ret) => break ret,
                Poll::Pending => panic!(),
            }
        };

        assert_eq!(ret, 8);
    }

    #[test]
    fn test_extend_mut_async_yielding() {
        use core::pin::pin;
        use core::task::{Context, Poll, Waker};

        let mut x = 5;

        async fn want_static(x: &'static mut i32) -> &'static mut i32 {
            let mut i = 0;

            let yield_fn = core::future::poll_fn(|cx| {
                *x += 1;

                if i == 20 {
                    return Poll::Ready(());
                } else {
                    i += 1;
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
            });

            yield_fn.await;

            x
        }

        let fut = unsafe { extend_mut_async(&mut x, async |x| (want_static(x).await, 8)) };
        let mut fut = pin!(fut);
        let ret = loop {
            match fut.as_mut().poll(&mut Context::from_waker(&Waker::noop())) {
                Poll::Ready(ret) => break ret,
                Poll::Pending => continue,
            }
        };

        assert_eq!(ret, 8);
        assert_eq!(x, 26);
    }
}