#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
#![recursion_limit = "512"]
#![cfg_attr(feature = "async", feature(async_fn_traits))]
use core::ptr;
#[cfg(feature = "async")]
use core::{
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use aborts::{abort_no_unwind, abort_on_unwind};
mod aborts;
mod impls;
/// Method-style access to [`extend_mut`]-like lifetime extension.
///
/// `self` is consumed (hence `Sized`) and converted into [`ExtendMut::Extended`],
/// a form whose borrows carry the caller-chosen lifetime `'b`. The callback must
/// hand that extended value back. As exercised in this file's tests, for
/// `&mut T` the extended value is a `&'b mut T`, and for a tuple of `&mut`
/// references it is the tuple of `&'b mut` references.
///
/// NOTE(review): the implementations live in the `impls` module, which is not
/// visible here; the exact per-type behaviour should be confirmed there.
pub trait ExtendMut<'b>: Sized {
    /// The lifetime-extended form of `Self` that is passed to the callback and
    /// must be returned from it.
    type Extended;
    /// Calls `f` with the extended value and returns the caller-visible result.
    ///
    /// `f` may return the extended value alone (result `()`) or the extended
    /// value paired with a result, via [`IntoExtendMutReturn`].
    fn extend_mut<R, ER: IntoExtendMutReturn<Self::Extended, R>>(
        self,
        f: impl FnOnce(Self::Extended) -> ER,
    ) -> R;
    /// Asynchronous counterpart of [`ExtendMut::extend_mut`]; only declared when
    /// the `assume-non-forget` feature is enabled.
    ///
    /// NOTE(review): `Future` is only imported under the `async` feature, so this
    /// presumably relies on `assume-non-forget` implying `async` — confirm in
    /// Cargo.toml.
    #[cfg(feature = "assume-non-forget")]
    fn extend_mut_async<R, ER: IntoExtendMutReturn<Self::Extended, R>>(
        self,
        f: impl AsyncFnOnce(Self::Extended) -> ER,
    ) -> impl Future<Output = R>;
}
/// Splits a callback's return value into the extended value that must be
/// handed back (`T`) and the caller-visible result (`R`).
///
/// As used by [`extend_mut`] in this file's tests, returning just the extended
/// reference yields `R = ()`, and returning `(reference, value)` yields `value`.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. In
/// [`extend_mut`], `into_extend_mut_return` runs after the callback and outside
/// the abort-on-unwind guard, and only the returned `T` is compared against the
/// original pointer. Implementations therefore presumably must:
/// - return the very reference(s) received from the callback;
/// - not retain or leak them anywhere else, including inside `R`;
/// - not panic.
///
/// TODO confirm against the `impls` module.
pub unsafe trait IntoExtendMutReturn<T, R> {
    /// Consumes `self`, returning `(extended_value, result)`.
    fn into_extend_mut_return(self) -> (T, R);
}
/// Safe-Rust baseline for [`extend_mut`] when `'b` is no longer than `'a`.
///
/// With `'a: 'b`, lending `&'a mut T` out as `&'b mut T` is an ordinary
/// reborrow, so this compiles without `unsafe`. `extend_mut` offers the same
/// shape for an arbitrary `'b` by additionally requiring the reference back.
// Never called by the crate; it exists so the compiler type-checks the signature.
#[allow(dead_code)]
fn extend_mut_proof_for_smaller<'a: 'b, 'b, T: 'b, R>(
    mut_ref: &'a mut T,
    f: impl FnOnce(&'b mut T) -> (&'b mut T, R),
) -> R {
    let (_handed_back, result) = f(mut_ref);
    result
}
/// Lends `mut_ref` to `f` as a `&'b mut T` for a caller-chosen lifetime `'b`
/// (which may outlive `'a`, e.g. `'static`), on the condition that `f` hands
/// the very same reference back.
///
/// `f` returns anything implementing [`IntoExtendMutReturn`]: for example the
/// reference alone (result `()`), or the reference paired with a result `R`.
///
/// # Panics
///
/// Panics if `*mut_ref` is zero-sized.
///
/// # Aborts
///
/// - `f` runs under `abort_on_unwind`, so a panic inside it aborts rather than
///   unwinds.
/// - The process is aborted with "ExtendMut: Pointer changed" if the reference
///   handed back is not the one that was lent out.
#[inline(always)]
pub fn extend_mut<'a, 'b, T: ?Sized + 'b, F, R, ExtR>(mut_ref: &'a mut T, f: F) -> R
where
    F: FnOnce(&'b mut T) -> ExtR,
    ExtR: IntoExtendMutReturn<&'b mut T, R>,
{
    // Distinct zero-sized references can share an address, so the identity
    // check below could not tell them apart; zero-sized referents are rejected.
    assert!(size_of_val::<T>(&*mut_ref) != 0);
    let original: *mut T = ptr::from_mut(mut_ref);
    let (returned, output) = abort_on_unwind(
        // SAFETY: `original` comes from the live, exclusive `&'a mut T`, which
        // is not used again here. Extending to `'b` is justified only because
        // the reference must come back and is checked below.
        #[inline(always)]
        move || f(unsafe { &mut *original }),
    )
    .into_extend_mut_return();
    if !core::ptr::eq(original, ptr::from_mut(returned)) {
        abort_no_unwind("ExtendMut: Pointer changed");
    }
    output
}
#[cfg(feature = "async")]
pin_project_lite::pin_project! {
    /// Future returned by `extend_mut_async`.
    ///
    /// Wraps the user's future, which holds an extended `&'b mut T`. When that
    /// future completes, its output is split with [`IntoExtendMutReturn`], and
    /// the process aborts unless the returned reference equals (per
    /// `core::ptr::eq`) the one originally lent out.
    ///
    /// Dropping this future before it has yielded `Poll::Ready` aborts the
    /// process.
    ///
    /// NOTE(review): the raw `*mut T` field makes this type `!Send` and `!Sync`
    /// regardless of `Fut`; confirm whether that restriction is intended.
    pub struct ExtendMutFuture<'a, 'b, T: ?Sized, Fut, R, ExtR> {
        // Address of the original `&'a mut T`; in this file it is only used for
        // the identity comparison in `poll`.
        ptr: *mut T,
        // Ties the borrow `'a`, the extended lifetime `'b` and the result types
        // to the struct.
        marker: PhantomData<(&'a mut T, &'b mut T, R, ExtR)>,
        // The user's future, which holds the extended reference.
        #[pin]
        future: Fut,
        // Set once `future` has completed and the pointer check has passed.
        ready: bool,
    }
    impl<'a, 'b, T: ?Sized, Fut, R, ExtR> PinnedDrop for ExtendMutFuture<'a, 'b, T, Fut, R, ExtR> {
        fn drop(this: Pin<&mut Self>) {
            // Releasing the `'a` borrow while `future` may still hold the
            // extended reference is not allowed, so an early drop aborts.
            if !*this.project().ready {
                abort_no_unwind("Cannot drop ExtendMutFuture before it yields Poll::Ready");
            }
        }
    }
}
#[cfg(feature = "async")]
/// Drives the wrapped future and, once it completes, verifies that the
/// extended reference it hands back is the one originally lent out.
impl<'a, 'b, T, Fut, R, ExdR> Future for ExtendMutFuture<'a, 'b, T, Fut, R, ExdR>
where
    T: ?Sized,
    ExdR: IntoExtendMutReturn<&'b mut T, R>,
    Fut: Future<Output = ExdR>,
{
    type Output = R;

    /// Polls the inner future under `abort_on_unwind`.
    ///
    /// On completion it splits the output, aborts unless the returned reference
    /// points at the original referent, marks itself ready, and yields the
    /// caller's result. Once ready, every further poll returns `Poll::Pending`
    /// without touching the inner future.
    #[inline(always)]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let fields = self.project();
        if *fields.ready {
            // Already completed: the inner future must not be polled again.
            return Poll::Pending;
        }
        let original = *fields.ptr;
        let inner = fields.future;
        let Poll::Ready(output) = abort_on_unwind(
            #[inline(always)]
            move || inner.poll(cx),
        ) else {
            return Poll::Pending;
        };
        let (extended, value) = output.into_extend_mut_return();
        if !core::ptr::eq(original, ptr::from_mut(extended)) {
            abort_no_unwind("ExtendMut: Pointer changed");
        }
        *fields.ready = true;
        Poll::Ready(value)
    }
}
/// Asynchronous counterpart of [`extend_mut`]: lends `mut_ref` to the async
/// callback `f` as a `&'b mut T` for a caller-chosen lifetime `'b`. The future
/// produced by `f` must eventually hand the same reference back.
///
/// Unsized referents (`T: ?Sized`) are accepted, matching [`extend_mut`] and the
/// `assume-non-forget` variant of this function.
///
/// # Safety
///
/// The caller must ensure the returned [`ExtendMutFuture`] is never leaked
/// (e.g. via `core::mem::forget` or a reference cycle) before it has yielded
/// `Poll::Ready`.
///
/// Dropping it early is caught, because it aborts the process. Leaking it
/// would instead end the `'a` borrow while the inner future, or wherever it put
/// the reference, may still hold the extended `&'b mut T`.
///
/// # Panics
///
/// Panics if `*mut_ref` is zero-sized.
#[cfg(feature = "async")]
#[cfg(not(feature = "assume-non-forget"))]
pub unsafe fn extend_mut_async<'a, 'b, T: ?Sized + 'b, F, R, ExdR>(
    mut_ref: &'a mut T,
    f: F,
) -> ExtendMutFuture<'a, 'b, T, F::CallOnceFuture, R, ExdR>
where
    ExdR: IntoExtendMutReturn<&'b mut T, R>,
    F: AsyncFnOnce(&'b mut T) -> ExdR,
{
    // SAFETY: the no-leak requirement is forwarded unchanged to our caller via
    // this function's `# Safety` section.
    unsafe { extend_mut_async_inner(mut_ref, f) }
}
#[cfg(feature = "async")]
#[cfg(feature = "assume-non-forget")]
/// Asynchronous counterpart of [`extend_mut`]: lends `mut_ref` to the async
/// callback `f` as a `&'b mut T` for a caller-chosen lifetime `'b`. The future
/// produced by `f` must eventually hand the same reference back.
///
/// Without the `assume-non-forget` feature this function is `unsafe`. With it
/// the function is safe, under the assumption (per the feature name) that the
/// returned [`ExtendMutFuture`] is never leaked before it yields `Poll::Ready`.
/// Dropping it early aborts the process.
///
/// # Panics
///
/// Panics if `*mut_ref` is zero-sized.
pub fn extend_mut_async<'a, 'b, T: ?Sized + 'b, F, R, ExdR>(
    mut_ref: &'a mut T,
    f: F,
) -> ExtendMutFuture<'a, 'b, T, F::CallOnceFuture, R, ExdR>
where
    ExdR: IntoExtendMutReturn<&'b mut T, R>,
    F: AsyncFnOnce(&'b mut T) -> ExdR,
{
    // SAFETY: `extend_mut_async_inner` requires that the returned future is not
    // leaked before completion; enabling `assume-non-forget` is taken as the
    // crate user's guarantee of that.
    unsafe { extend_mut_async_inner(mut_ref, f) }
}
/// Shared implementation behind both `extend_mut_async` variants.
///
/// Calls `f` with a reference carrying the caller-chosen lifetime `'b`, then
/// wraps the resulting future in an [`ExtendMutFuture`]. That wrapper checks on
/// completion that the same reference is handed back, and aborts if it is
/// dropped early.
///
/// # Safety
///
/// The returned future must not be leaked (e.g. via `core::mem::forget` or a
/// reference cycle) before it yields `Poll::Ready`. Otherwise the `'a` borrow
/// ends while the extended reference may still be held.
///
/// # Panics
///
/// Panics if `*mut_ref` is zero-sized.
///
/// # Aborts
///
/// Calling `f` happens under `abort_on_unwind`, so a panic there aborts.
#[cfg(feature = "async")]
unsafe fn extend_mut_async_inner<'a, 'b, T: ?Sized + 'b, F, R, ExdR>(
    mut_ref: &'a mut T,
    f: F,
) -> ExtendMutFuture<'a, 'b, T, F::CallOnceFuture, R, ExdR>
where
    ExdR: IntoExtendMutReturn<&'b mut T, R>,
    F: AsyncFnOnce(&'b mut T) -> ExdR,
{
    // Same rule as `extend_mut`: distinct zero-sized references can share an
    // address, so the final pointer check could not identify them.
    assert!(size_of_val::<T>(&*mut_ref) != 0);
    let ptr = ptr::from_mut(mut_ref);
    // Calling `f` can run arbitrary user code synchronously: any plain
    // `FnOnce(..) -> impl Future` also implements `AsyncFnOnce`. If that code
    // stashed the extended reference and then panicked, unwinding would end the
    // `'a` borrow with no `ExtendMutFuture` in place yet to abort. So, exactly as
    // in `extend_mut`, the call is guarded with `abort_on_unwind`.
    let future = abort_on_unwind(
        // SAFETY: `ptr` was just derived from the exclusive `&'a mut T`, which is
        // not used again here. The extended lifetime is upheld by
        // `ExtendMutFuture` (identity check on completion, abort on early drop)
        // together with this function's no-leak contract.
        #[inline(always)]
        move || f(unsafe { &mut *ptr }),
    );
    ExtendMutFuture {
        ptr,
        marker: PhantomData,
        future,
        ready: false,
    }
}
#[cfg(test)]
mod test {
    use super::*;
    // Exercises the synchronous API surface, mostly through compile-time type checks:
    // - the free function and the `ExtendMut` method;
    // - receivers that are `&mut T`, tuples of values, and tuples of `&mut` references;
    // - callbacks returning the extended value alone (result `()`) or paired with "hi".
    // The `&'static` closure annotations check that `'b` can be chosen as `'static`.
    #[rustfmt::skip]
    #[test]
    fn test_sync_api() {
        let (mut t1, mut t2, mut t3, mut t4) = (1, 2, 3, 4);
        let () = extend_mut(&mut t1, |t1: &'static mut u8| t1);
        let "hi" = extend_mut(&mut t1, |t1| (t1, "hi")) else { panic!() };
        let () = extend_mut(&mut t1, |t1| (t1, ()));
        let () = t1.extend_mut(|t1| t1);
        let () = (&mut t1).extend_mut(|t1: &'static mut u8| t1);
        let "hi" = t1.extend_mut(|t1| (t1, "hi")) else { panic!() };
        let "hi" = (&mut t1).extend_mut(|t1| (t1, "hi")) else { panic!() };
        let () = (t1, t2).extend_mut(|it: &'static mut (u8, u8)| it);
        let () = (&mut t1, &mut t2).extend_mut(|it: (&'static mut u8, &'static mut u8)| it);
        let () = (&mut (t1, t2)).extend_mut(|it: &'static mut (u8, u8)| it);
        let "hi" = (t1, t2).extend_mut(|it| (it, "hi")) else { panic!() };
        let "hi" = (&mut t1, &mut t2).extend_mut(|it| (it, "hi")) else { panic!() };
        let "hi" = (&mut (t1, t2)).extend_mut(|it| (it, "hi")) else { panic!() };
        let () = (t1, t2, t3).extend_mut(|it: &'static mut (u8, u8, u8)| it);
        let () = (&mut t1, &mut t2, &mut t3).extend_mut(|it: (&'static mut u8, &'static mut u8, &mut u8)| it);
        let "hi" = (t1, t2, t3).extend_mut(|it| (it, "hi")) else { panic!() };
        let "hi" = (&mut t1, &mut t2, &mut t3).extend_mut(|it| (it, "hi")) else { panic!() };
        let () = (t1, t2, t3, t4).extend_mut(|it: &'static mut (u8, u8, u8, u8)| it);
        let () = (&mut t1, &mut t2, &mut t3, &mut t4).extend_mut(|it: (&mut u8, &mut u8, &mut u8, &mut u8)| it);
        let "hi" = (t1, t2, t3, t4).extend_mut(|it| (it, "hi")) else { panic!() };
        let "hi" = (&mut t1, &mut t2, &mut t3, &mut t4).extend_mut(|it| (it, "hi")) else { panic!() };
        let () = <_>::extend_mut(&mut (t1, t2, t3, t4), |it: &'static mut (u8, u8, u8, u8)| it);
        let () = <_>::extend_mut((&mut t1, &mut t2, &mut t3, &mut t4), |it| it);
    }
    // Checks that mutations made through the extended `'static` reference are
    // visible in the original variable once `extend_mut` returns.
    #[test]
    fn test_extend_mut() {
        let mut x = 5;
        // Demands a `'static` borrow; adds 2 and returns the same reference.
        fn want_static(x: &'static mut i32) -> &'static mut i32 {
            *x += 1;
            *x += 1;
            x
        }
        extend_mut(&mut x, |x| want_static(x));
        assert_eq!(x, 7);
        let hi = x.extend_mut(|x| (want_static(x), "hi"));
        assert_eq!(hi, "hi");
        assert_eq!(x, 9);
        x.extend_mut(want_static);
        assert_eq!(x, 11);
        let mut y = 7;
        let mut z = 7;
        let hi = <_>::extend_mut((&mut x, &mut y, &mut z), |(x, y, z)| {
            ((want_static(x), y, z), "hi")
        });
        assert_eq!(hi, "hi");
    }
    // Async case where the inner future completes on its first poll.
    #[test]
    #[cfg(feature = "async")]
    fn test_extend_mut_async_immediate() {
        use core::pin::pin;
        use core::task::{Context, Poll, Waker};
        let mut x = 5;
        async fn want_static(x: &'static mut i32) -> &'static mut i32 {
            assert_eq!(*x, 5);
            x
        }
        // NOTE(review): with `assume-non-forget` enabled, `extend_mut_async` is a
        // safe fn, so this `unsafe` block is redundant in that configuration.
        let fut = unsafe { extend_mut_async(&mut x, async |x| (want_static(x).await, 8)) };
        let mut fut = pin!(fut);
        // A single poll must complete; `Pending` fails the test.
        let ret = loop {
            match fut.as_mut().poll(&mut Context::from_waker(&Waker::noop())) {
                Poll::Ready(ret) => break ret,
                Poll::Pending => panic!(),
            }
        };
        assert_eq!(ret, 8);
    }
    // Async case where the inner future returns `Pending` 20 times before
    // completing. The `poll_fn` closure increments `*x` on each of its 21 polls,
    // so `x` ends at 5 + 21 = 26.
    #[test]
    #[cfg(feature = "async")]
    fn test_extend_mut_async_yielding() {
        use core::pin::pin;
        use core::task::{Context, Poll, Waker};
        let mut x = 5;
        async fn want_static(x: &'static mut i32) -> &'static mut i32 {
            let mut i = 0;
            let yield_fn = core::future::poll_fn(|cx| {
                *x += 1;
                if i == 20 {
                    return Poll::Ready(());
                } else {
                    i += 1;
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
            });
            yield_fn.await;
            x
        }
        // Scoped so the pinned future (and its borrow of `x`) ends before the
        // final assertion reads `x`.
        {
            let fut = unsafe { extend_mut_async(&mut x, async |x| want_static(x).await) };
            let mut fut = pin!(fut);
            // Busy-poll until ready; the result type is `()`.
            () = loop {
                match fut.as_mut().poll(&mut Context::from_waker(&Waker::noop())) {
                    Poll::Ready(ret) => break ret,
                    Poll::Pending => continue,
                }
            };
        }
        assert_eq!(x, 26);
    }
}