// ref-ops 0.2.5
//
// An escape hatch for implementing `ops` traits for references to newtypes.
use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub};

// Helper for attaching computed documentation to an item.
//
// Usage: `doc!(line1, line2, ..., @item)`. Each `lineN` is any expression that
// `#[doc = ...]` accepts (a string literal or a `concat!`/`stringify!` call), and
// each one must be followed by a comma. The item after the `@` marker is emitted
// unchanged, preceded by one `#[doc = lineN]` attribute per line, in order.
macro_rules! doc {
    ($($line:expr,)* @ $target:item) => {
        $(#[doc = $line])*
        $target
    };
}

// Defines the mutable-reference operator trait for one binary operator.
//
// Arguments:
// * `$Op` / `$op`        - the `core::ops` trait and its method (e.g. `Add`, `add`);
// * `$RefOp` / `$ref_op` - the trait to generate and its method
//                          (e.g. `RefMutAdd`, `ref_mut_add`).
//
// Expands to:
// * a private module `$op` containing the `Sealed` supertrait;
// * the public trait `$RefOp<Rhs = Self>`, whose rustdoc holds two doctests built
//   from the identifiers above (one `compile_fail`, one that must compile);
// * blanket impls of `Sealed` and `$RefOp` for every `T: ?Sized` such that
//   `for<'a> &'a mut T: $Op<Rhs, Output = O>`.
//
// The expansion uses the `doc!` macro and names `$Op` directly (in bounds and via
// the `self.$op(..)` method call), so both must be in scope at the call site.
macro_rules! def_binary {
    ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
        // `Sealed` is `pub` but lives in a private module, so code outside the
        // invoking module cannot name it and cannot satisfy `$RefOp`'s supertrait
        // except through the blanket impl below.
        mod $op {
            pub trait Sealed<Rhs = Self> {}
        }

        doc!(
            concat!("`", stringify!($op), "` operation through mutable references."),
            "",
            "As of Rust 1.73.0, the following code does not compile:",
            "```compile_fail",
            concat!("use core::ops::", stringify!($Op), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, 'b, T, U> ", stringify!($Op), "<&'b mut A<U>> for &'a mut A<T>"),
            "where",
            concat!("    &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    type Output = A<<&'a mut T as ", stringify!($Op), "<&'b mut U>>::Output>;"),
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'b mut A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($op), "(&mut rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(mut a: T, mut b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            "but the following code does:",
            "```",
            concat!("use core::ops::", stringify!($Op), ";"),
            concat!("use ref_ops::", stringify!($RefOp), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, T, U> ", stringify!($Op), "<&'a mut A<U>> for &mut A<T>"),
            "where",
            concat!("    T: ", stringify!($RefOp), "<&'a mut U>,"),
            "{",
            "    type Output = A<T::Output>;",
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'a mut A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($ref_op), "(&mut rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(mut a: T, mut b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            @pub trait $RefOp<Rhs = Self>: $op::Sealed<Rhs> {
                doc!(
                    concat!("The resulting type after applying the `", stringify!($op), "` operation."),
                    @type Output;
                );

                doc!(
                    concat!("Performs the `", stringify!($op), "` operation."),
                    @fn $ref_op(&mut self, rhs: Rhs) -> Self::Output;
                );
            }
        );

        // Seal exactly the types the `$RefOp` blanket impl below covers (same
        // bounds), so the supertrait requirement is always satisfied there.
        impl<T, Rhs, O> $op::Sealed<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a mut T: $Op<Rhs, Output = O>,
        {
        }

        // Blanket impl: `$ref_op` forwards to `$Op::$op` on the `&mut T` receiver.
        // The higher-ranked bound has a single `O`, so the output type must be the
        // same for every borrow lifetime `'a`.
        impl<T, Rhs, O> $RefOp<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a mut T: $Op<Rhs, Output = O>,
        {
            type Output = O;

            fn $ref_op(&mut self, rhs: Rhs) -> O {
                self.$op(rhs)
            }
        }
    };
}

// Arithmetic operators: `&mut T op Rhs` exposed as `T::ref_mut_op(&mut self, rhs)`.
def_binary! { Add, add, RefMutAdd, ref_mut_add }
def_binary! { Sub, sub, RefMutSub, ref_mut_sub }
def_binary! { Mul, mul, RefMutMul, ref_mut_mul }
def_binary! { Div, div, RefMutDiv, ref_mut_div }
def_binary! { Rem, rem, RefMutRem, ref_mut_rem }

// Bit-shift operators.
def_binary! { Shl, shl, RefMutShl, ref_mut_shl }
def_binary! { Shr, shr, RefMutShr, ref_mut_shr }

// Bitwise logical operators.
def_binary! { BitAnd, bitand, RefMutBitAnd, ref_mut_bitand }
def_binary! { BitOr, bitor, RefMutBitOr, ref_mut_bitor }
def_binary! { BitXor, bitxor, RefMutBitXor, ref_mut_bitxor }

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        RefAdd, RefBitAnd, RefBitOr, RefBitXor, RefDiv, RefMul, RefRem, RefShl, RefShr, RefSub,
    };

    // Inner newtype used as the payload of `A` in the tests below. `Debug` is
    // derived so that `assert_eq!` can print both sides when a test fails.
    #[derive(Debug, PartialEq)]
    struct B<T>(T);

    // Implements `$Op<&mut B<U>> for &mut B<T>` by delegating to the
    // shared-reference trait `$RefOp` of the wrapped value (`T::$ref_op(&U)`).
    macro_rules! impl_binary {
        ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
            impl<'a, T, U> $Op<&'a mut B<U>> for &mut B<T>
            where
                T: $RefOp<&'a U>,
            {
                type Output = B<T::Output>;

                fn $op(self, rhs: &'a mut B<U>) -> Self::Output {
                    B(self.0.$ref_op(&rhs.0))
                }
            }
        };
    }

    impl_binary!(Add, add, RefAdd, ref_add);
    impl_binary!(Sub, sub, RefSub, ref_sub);
    impl_binary!(Mul, mul, RefMul, ref_mul);
    impl_binary!(Div, div, RefDiv, ref_div);
    impl_binary!(Rem, rem, RefRem, ref_rem);
    impl_binary!(Shl, shl, RefShl, ref_shl);
    impl_binary!(Shr, shr, RefShr, ref_shr);
    impl_binary!(BitAnd, bitand, RefBitAnd, ref_bitand);
    impl_binary!(BitOr, bitor, RefBitOr, ref_bitor);
    impl_binary!(BitXor, bitxor, RefBitXor, ref_bitxor);

    // Generates a `#[test]` named `$fn` that:
    // * defines a local newtype `A` whose `&mut A<T> $Op &mut A<U>` impl is
    //   written in terms of `$RefOp` (the trait under test);
    // * defines generic helpers `f`/`g` bounded by
    //   `for<'a, 'b> &'a mut T: $Op<&'b mut U>` and calls `g` with two `$dummy`
    //   values;
    // * asserts that `$actual` evaluates to `$expected`.
    macro_rules! test_binary {
        (
            $fn:ident,
            $Op:ident,
            $op:ident,
            $RefOp:ident,
            $ref_op:ident,
            $actual:expr,
            $expected:expr,
            $dummy:expr
        ) => {
            #[test]
            fn $fn() {
                #[derive(Debug, PartialEq)]
                struct A<T: ?Sized>(T);

                impl<'a, T, U> $Op<&'a mut A<U>> for &mut A<T>
                where
                    T: $RefOp<&'a mut U>,
                {
                    type Output = A<T::Output>;

                    fn $op(self, rhs: &'a mut A<U>) -> Self::Output {
                        A(self.0.$ref_op(&mut rhs.0))
                    }
                }

                fn f<T, U>(mut a: T, mut b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    let _a_op_b = (&mut a).$op(&mut b);

                    // to do something with `a`, `b`, and `_a_op_b`
                }

                fn g<T, U>(a: T, b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    f(a, b);
                }

                g($dummy, $dummy);

                assert_eq!($actual, $expected);
            }
        };
    }

    test_binary!(
        test_add,
        Add,
        add,
        RefMutAdd,
        ref_mut_add,
        &mut A(B(1.0)) + &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_sub,
        Sub,
        sub,
        RefMutSub,
        ref_mut_sub,
        &mut A(B(3.0)) - &mut A(B(1.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_mul,
        Mul,
        mul,
        RefMutMul,
        ref_mut_mul,
        &mut A(B(2.0)) * &mut A(B(3.0)),
        A(B(6.0)),
        B(1.0)
    );
    test_binary!(
        test_div,
        Div,
        div,
        RefMutDiv,
        ref_mut_div,
        &mut A(B(6.0)) / &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_rem,
        Rem,
        rem,
        RefMutRem,
        ref_mut_rem,
        &mut A(B(6.0)) % &mut A(B(4.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_shl,
        Shl,
        shl,
        RefMutShl,
        ref_mut_shl,
        &mut A(B(3)) << &mut A(B(2)),
        A(B(12)),
        B(1)
    );
    test_binary!(
        test_shr,
        Shr,
        shr,
        RefMutShr,
        ref_mut_shr,
        &mut A(B(12)) >> &mut A(B(2)),
        A(B(3)),
        B(1)
    );
    test_binary!(
        test_bitand,
        BitAnd,
        bitand,
        RefMutBitAnd,
        ref_mut_bitand,
        &mut A(B(6)) & &mut A(B(5)),
        A(B(4)),
        B(1)
    );
    test_binary!(
        test_bitor,
        BitOr,
        bitor,
        RefMutBitOr,
        ref_mut_bitor,
        &mut A(B(3)) | &mut A(B(5)),
        A(B(7)),
        B(1)
    );
    test_binary!(
        test_bitxor,
        BitXor,
        bitxor,
        RefMutBitXor,
        ref_mut_bitxor,
        &mut A(B(3)) ^ &mut A(B(5)),
        A(B(6)),
        B(1)
    );
}