use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub};
// Emits the item that follows `@`, preceded by one `#[doc = ...]` attribute per
// leading comma-terminated expression. A plain `///` comment cannot contain macro
// calls, so this lets callers build doc lines with `concat!`/`stringify!`.
macro_rules! doc {
    (
        $( $line:expr, )*
        @ $documented:item
    ) => {
        $( #[doc = $line] )*
        $documented
    };
}
// Defines, for one binary operator, a sealed `$RefOp` trait that performs `$Op`
// through `&mut self`, plus a blanket impl for every `T` whose `&mut T`
// implements `$Op<Rhs>`.
//
// Arguments:
// - `$Op` / `$op`: the `core::ops` trait and its method (e.g. `Add`, `add`);
//   `$op` also names the private module that holds the seal.
// - `$RefOp` / `$ref_op`: the public trait and method to generate
//   (e.g. `RefMutAdd`, `ref_mut_add`).
macro_rules! def_binary {
($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
// Private module: `Sealed` is `pub` so it may appear as a supertrait of the
// public `$RefOp`, but code outside this module cannot name it and therefore
// cannot implement `$RefOp` itself.
mod $op {
pub trait Sealed<Rhs = Self> {}
}
// The public trait. Its rustdoc is assembled line by line through `doc!`.
// The first doctest (`compile_fail`) shows a naive impl bounded on
// `&'a mut T: $Op<&'b mut U>`. The second shows the same shape bounded on
// `T: $RefOp<&'a mut U>` instead, which the text says does compile.
doc!(
concat!("`", stringify!($op), "` operation through mutable references."),
"",
"As of Rust 1.73.0, the following code does not compile:",
// First doctest: expected to fail compilation.
"```compile_fail",
concat!("use core::ops::", stringify!($Op), ";"),
"",
"struct A<T>(T);",
"",
concat!("impl<'a, 'b, T, U> ", stringify!($Op), "<&'b mut A<U>> for &'a mut A<T>"),
"where",
concat!(" &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
"{",
concat!(" type Output = A<<&'a mut T as ", stringify!($Op), "<&'b mut U>>::Output>;"),
"",
concat!(" fn ", stringify!($op), "(self, rhs: &'b mut A<U>) -> Self::Output {"),
concat!(" A(self.0.", stringify!($op), "(&mut rhs.0))"),
" }",
"}",
"",
"fn _f<T, U>(mut a: T, mut b: U)",
"where",
concat!(" for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
"{",
concat!(" let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
"",
concat!(" // to do something with `a`, `b`, and `_a_op_b`"),
"}",
"",
"fn _g<T, U>(a: T, b: U)",
"where",
concat!(" for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
"{",
" _f(a, b);",
"}",
"```",
"but the following code does:",
// Second doctest: the same `_f`/`_g` pair, with `A`'s impl routed through `$RefOp`.
"```",
concat!("use core::ops::", stringify!($Op), ";"),
concat!("use ref_ops::", stringify!($RefOp),";"),
"",
"struct A<T>(T);",
"",
concat!("impl<'a, T, U> ", stringify!($Op), "<&'a mut A<U>> for &mut A<T>"),
"where",
concat!(" T: ", stringify!($RefOp), "<&'a mut U>,"),
"{",
" type Output = A<T::Output>;",
"",
concat!(" fn ", stringify!($op), "(self, rhs: &'a mut A<U>) -> Self::Output {"),
concat!(" A(self.0.", stringify!($ref_op), "(&mut rhs.0))"),
" }",
"}",
"",
"fn _f<T, U>(mut a: T, mut b: U)",
"where",
concat!(" for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
"{",
concat!(" let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
"",
concat!(" // to do something with `a`, `b`, and `_a_op_b`"),
"}",
"",
"fn _g<T, U>(a: T, b: U)",
"where",
concat!(" for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
"{",
" _f(a, b);",
"}",
"```",
@pub trait $RefOp<Rhs = Self>: $op::Sealed<Rhs> {
doc!(
concat!("The resulting type after applying `", stringify!($op), "` operation."),
@type Output;
);
doc!(
concat!("Performs `", stringify!($op), "` operation."),
@fn $ref_op(&mut self, rhs: Rhs) -> Self::Output;
);
}
);
// Seal every type that receives the blanket `$RefOp` impl below. The bounds
// are identical to that impl's, so its `Sealed` supertrait is always satisfied.
impl<T, Rhs, O> $op::Sealed<Rhs> for T
where
T: ?Sized,
for<'a> &'a mut T: $Op<Rhs, Output = O>,
{
}
// Blanket impl: `T` gets `$RefOp<Rhs>` whenever `&'a mut T: $Op<Rhs>` holds for
// every lifetime `'a`. Binding `Output = O` under `for<'a>` forces a single
// output type that does not depend on `'a`, which `type Output = O` requires.
impl<T, Rhs, O> $RefOp<Rhs> for T
where
T: ?Sized,
for<'a> &'a mut T: $Op<Rhs, Output = O>,
{
type Output = O;
fn $ref_op(&mut self, rhs: Rhs) -> O {
// `self` has type `&mut T`, so this calls `$Op::$op` through the
// `&mut T: $Op<Rhs>` bound in the `where` clause.
self.$op(rhs)
}
}
};
}
// Arithmetic operators.
def_binary! { Add, add, RefMutAdd, ref_mut_add }
def_binary! { Sub, sub, RefMutSub, ref_mut_sub }
def_binary! { Mul, mul, RefMutMul, ref_mut_mul }
def_binary! { Div, div, RefMutDiv, ref_mut_div }
def_binary! { Rem, rem, RefMutRem, ref_mut_rem }
// Shift operators.
def_binary! { Shl, shl, RefMutShl, ref_mut_shl }
def_binary! { Shr, shr, RefMutShr, ref_mut_shr }
// Bitwise operators.
def_binary! { BitAnd, bitand, RefMutBitAnd, ref_mut_bitand }
def_binary! { BitOr, bitor, RefMutBitOr, ref_mut_bitor }
def_binary! { BitXor, bitxor, RefMutBitXor, ref_mut_bitxor }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        RefAdd, RefBitAnd, RefBitOr, RefBitXor, RefDiv, RefMul, RefRem, RefShl, RefShr, RefSub,
    };

    // Operand wrapper used inside each test's `A<T>`. `Debug` is derived so that
    // `assert_eq!` can print both sides when a test fails.
    #[derive(Debug, PartialEq)]
    struct B<T>(T);

    // Implements `&mut B<T> op &'a mut B<U>` by delegating to the crate's `$RefOp`
    // trait with a shared `&'a U` operand. This provides the `&mut B<T>: $Op<..>`
    // impl that the `RefMut*` blanket impls require of `B<T>`.
    macro_rules! impl_binary {
        ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
            impl<'a, T, U> $Op<&'a mut B<U>> for &mut B<T>
            where
                T: $RefOp<&'a U>,
            {
                type Output = B<T::Output>;
                fn $op(self, rhs: &'a mut B<U>) -> Self::Output {
                    B(self.0.$ref_op(&rhs.0))
                }
            }
        };
    }
    impl_binary!(Add, add, RefAdd, ref_add);
    impl_binary!(Sub, sub, RefSub, ref_sub);
    impl_binary!(Mul, mul, RefMul, ref_mul);
    impl_binary!(Div, div, RefDiv, ref_div);
    impl_binary!(Rem, rem, RefRem, ref_rem);
    impl_binary!(Shl, shl, RefShl, ref_shl);
    impl_binary!(Shr, shr, RefShr, ref_shr);
    impl_binary!(BitAnd, bitand, RefBitAnd, ref_bitand);
    impl_binary!(BitOr, bitor, RefBitOr, ref_bitor);
    impl_binary!(BitXor, bitxor, RefBitXor, ref_bitxor);

    // Generates one `#[test]` per operator. Each test:
    // 1. defines a local `A<T>` whose `&mut A<T> op &mut A<U>` impl is bounded on
    //    `T: $RefOp<&mut U>` (the pattern shown in the trait docs);
    // 2. calls `g`, which forwards to `f`, on two `$dummy` operands, exercising the
    //    higher-ranked `for<'a, 'b> &'a mut T: $Op<&'b mut U>` bound used in the
    //    trait docs' `_f`/`_g` example;
    // 3. asserts at runtime that `$lhs` equals `$rhs`.
    macro_rules! test_binary {
        (
            $fn:ident,
            $Op:ident,
            $op:ident,
            $RefOp:ident,
            $ref_op:ident,
            $lhs:expr,
            $rhs:expr,
            $dummy:expr
        ) => {
            #[test]
            fn $fn() {
                #[derive(Debug, PartialEq)]
                struct A<T: ?Sized>(T);
                impl<'a, T, U> $Op<&'a mut A<U>> for &mut A<T>
                where
                    T: $RefOp<&'a mut U>,
                {
                    type Output = A<T::Output>;
                    fn $op(self, rhs: &'a mut A<U>) -> Self::Output {
                        A(self.0.$ref_op(&mut rhs.0))
                    }
                }
                fn f<T, U>(mut a: T, mut b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    let _a_op_b = (&mut a).$op(&mut b);
                }
                fn g<T, U>(a: T, b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    f(a, b);
                }
                g($dummy, $dummy);
                assert_eq!($lhs, $rhs);
            }
        };
    }
    test_binary!(
        test_add,
        Add,
        add,
        RefMutAdd,
        ref_mut_add,
        &mut A(B(1.0)) + &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_sub,
        Sub,
        sub,
        RefMutSub,
        ref_mut_sub,
        &mut A(B(3.0)) - &mut A(B(1.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_mul,
        Mul,
        mul,
        RefMutMul,
        ref_mut_mul,
        &mut A(B(2.0)) * &mut A(B(3.0)),
        A(B(6.0)),
        B(1.0)
    );
    test_binary!(
        test_div,
        Div,
        div,
        RefMutDiv,
        ref_mut_div,
        &mut A(B(6.0)) / &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_rem,
        Rem,
        rem,
        RefMutRem,
        ref_mut_rem,
        &mut A(B(6.0)) % &mut A(B(4.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_shl,
        Shl,
        shl,
        RefMutShl,
        ref_mut_shl,
        &mut A(B(3)) << &mut A(B(2)),
        A(B(12)),
        B(1)
    );
    test_binary!(
        test_shr,
        Shr,
        shr,
        RefMutShr,
        ref_mut_shr,
        &mut A(B(12)) >> &mut A(B(2)),
        A(B(3)),
        B(1)
    );
    test_binary!(
        test_bitand,
        BitAnd,
        bitand,
        RefMutBitAnd,
        ref_mut_bitand,
        &mut A(B(6)) & &mut A(B(5)),
        A(B(4)),
        B(1)
    );
    test_binary!(
        test_bitor,
        BitOr,
        bitor,
        RefMutBitOr,
        ref_mut_bitor,
        &mut A(B(3)) | &mut A(B(5)),
        A(B(7)),
        B(1)
    );
    test_binary!(
        test_bitxor,
        BitXor,
        bitxor,
        RefMutBitXor,
        ref_mut_bitxor,
        &mut A(B(3)) ^ &mut A(B(5)),
        A(B(6)),
        B(1)
    );
}