use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub};
/// Emits the trailing item (everything after the `@` marker) with each leading
/// expression attached to it as a `#[doc = ...]` attribute, in order.
///
/// Every doc expression must be followed by a comma. This lets other macros build
/// documentation from `concat!`/`stringify!` pieces, which a plain `///` comment cannot do.
macro_rules! doc {
    [$($line:expr,)* @$target:item] => {
        $(#[doc = $line])*
        $target
    }
}
/// Defines a by-reference counterpart of one `core::ops` binary-operator trait.
///
/// `def_binary!(Add, add, RefAdd, ref_add)` expands to:
///
/// - a private module `add` holding a `Sealed<Rhs>` marker trait, used as a supertrait
///   of the public trait (the sealed-trait pattern);
/// - the public trait `RefAdd<Rhs = Self>` with an associated `Output` type and a
///   `ref_add(&self, rhs: Rhs)` method;
/// - blanket impls of `Sealed` and `RefAdd` for every `T: ?Sized` such that, for every
///   lifetime `'a`, `&'a T: Add<Rhs>` holds with one and the same `Output` type.
macro_rules! def_binary {
    ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
        // The module is not `pub`, so `Sealed` cannot be named from outside the invoking
        // module. Downstream code therefore cannot satisfy the supertrait bound of the
        // public trait below.
        mod $op {
            pub trait Sealed<Rhs = Self> {}
        }

        doc!(
            concat!("`", stringify!($op), "` operation through references."),
            "",
            "As of Rust 1.72.1, the following code does not compile:",
            "```compile_fail",
            concat!("use core::ops::", stringify!($Op), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, 'b, T, U, O> ", stringify!($Op), "<&'b A<U>> for &'a A<T>"),
            "where",
            concat!(" &'a T: ", stringify!($Op), "<&'b U, Output = O>,"),
            "{",
            " type Output = A<O>;",
            "",
            concat!(" fn ", stringify!($op), "(self, rhs: &'b A<U>) -> Self::Output {"),
            concat!(" A(self.0.", stringify!($op), "(&rhs.0))"),
            " }",
            "}",
            "",
            "fn _f<T, U>(a: T, b: U)",
            "where",
            concat!(" for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            concat!(" let _a_op_b = (&a).", stringify!($op), "(&b);"),
            "",
            " // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!(" for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            " _f(a, b);",
            "}",
            "```",
            "but the following code does:",
            "```",
            concat!("use core::ops::", stringify!($Op), ";"),
            concat!("use ref_ops::", stringify!($RefOp), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, T, U> ", stringify!($Op), "<&'a A<U>> for &A<T>"),
            "where",
            concat!(" T: ", stringify!($RefOp), "<&'a U>,"),
            "{",
            " type Output = A<T::Output>;",
            "",
            concat!(" fn ", stringify!($op), "(self, rhs: &'a A<U>) -> Self::Output {"),
            concat!(" A(self.0.", stringify!($ref_op), "(&rhs.0))"),
            " }",
            "}",
            "",
            "fn _f<T, U>(a: T, b: U)",
            "where",
            concat!(" for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            concat!(" let _a_op_b = (&a).", stringify!($op), "(&b);"),
            "",
            " // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!(" for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            " _f(a, b);",
            "}",
            "```",
            @pub trait $RefOp<Rhs = Self>: $op::Sealed<Rhs> {
                doc!(
                    concat!(
                        "The resulting type after applying the `",
                        stringify!($op),
                        "` operation."
                    ),
                    @type Output;
                );

                // `#[must_use]` mirrors the attribute on the `core::ops` operator methods:
                // the result is the only effect, so discarding it is almost certainly a bug.
                doc!(
                    concat!("Performs the `", stringify!($op), "` operation."),
                    @#[must_use = "this returns the result of the operation, without modifying the original"]
                    fn $ref_op(&self, rhs: Rhs) -> Self::Output;
                );
            }
        );

        // `O` is bound via `Output = O` inside the higher-ranked bound, so the same output
        // type is required for every lifetime `'a` (it becomes the trait's `Output`).
        impl<T, Rhs, O> $op::Sealed<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a T: $Op<Rhs, Output = O>,
        {
        }

        impl<T, Rhs, O> $RefOp<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a T: $Op<Rhs, Output = O>,
        {
            type Output = O;

            fn $ref_op(&self, rhs: Rhs) -> O {
                // `self` has type `&T`, which implements `$Op<Rhs>` by the where clause.
                self.$op(rhs)
            }
        }
    };
}
def_binary!(Add, add, RefAdd, ref_add);
def_binary!(Sub, sub, RefSub, ref_sub);
def_binary!(Mul, mul, RefMul, ref_mul);
def_binary!(Div, div, RefDiv, ref_div);
def_binary!(Rem, rem, RefRem, ref_rem);
def_binary!(Shl, shl, RefShl, ref_shl);
def_binary!(Shr, shr, RefShr, ref_shr);
def_binary!(BitAnd, bitand, RefBitAnd, ref_bitand);
def_binary!(BitOr, bitor, RefBitOr, ref_bitor);
def_binary!(BitXor, bitxor, RefBitXor, ref_bitxor);
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates one `#[test]` for a single operator.
    ///
    /// - `$fn`: name of the generated test function.
    /// - `$Op` / `$op`: the `core::ops` trait and its method.
    /// - `$RefOp` / `$ref_op`: the by-reference trait and method from the parent module.
    /// - `$lhs`, `$rhs`: expressions (which may use the local wrapper `A`) asserted equal.
    /// - `$dummy`: a literal used to instantiate the generic helpers `f` and `g`.
    macro_rules! test_binary {
        (
            $fn:ident,
            $Op:ident,
            $op:ident,
            $RefOp:ident,
            $ref_op:ident,
            $lhs:expr,
            $rhs:expr,
            $dummy:literal
        ) => {
            #[test]
            fn $fn() {
                // `Debug` lets `assert_eq!` print both sides when a case fails.
                #[derive(Debug, PartialEq)]
                struct A<T: ?Sized>(T);

                // Same shape as the compiling example in the trait docs: the operator on
                // `&A<T>` is bounded through the by-reference trait, not through `&T`.
                impl<'a, T, U> $Op<&'a A<U>> for &A<T>
                where
                    T: ?Sized + $RefOp<&'a U>,
                {
                    type Output = A<T::Output>;

                    fn $op(self, rhs: &'a A<U>) -> Self::Output {
                        A(self.0.$ref_op(&rhs.0))
                    }
                }

                // `f` and `g` mirror `_f` and `_g` from the trait docs.
                fn f<T, U>(a: T, b: U)
                where
                    for<'a, 'b> &'a T: $Op<&'b U>,
                {
                    let _a_op_b = (&a).$op(&b);
                }

                fn g<T, U>(a: T, b: U)
                where
                    for<'a, 'b> &'a T: $Op<&'b U>,
                {
                    f(a, b);
                }

                g($dummy, $dummy);
                // Also instantiate the helpers through the by-reference-trait impl on `A`,
                // not only with primitives.
                g(A($dummy), A($dummy));
                assert_eq!($lhs, $rhs);
            }
        };
    }

    test_binary!(test_add, Add, add, RefAdd, ref_add, &A(1.0) + &A(2.0), A(3.0), 1.0);
    test_binary!(test_sub, Sub, sub, RefSub, ref_sub, &A(3.0) - &A(1.0), A(2.0), 1.0);
    test_binary!(test_mul, Mul, mul, RefMul, ref_mul, &A(2.0) * &A(3.0), A(6.0), 1.0);
    test_binary!(test_div, Div, div, RefDiv, ref_div, &A(6.0) / &A(2.0), A(3.0), 1.0);
    test_binary!(test_rem, Rem, rem, RefRem, ref_rem, &A(6.0) % &A(4.0), A(2.0), 1.0);
    test_binary!(test_shl, Shl, shl, RefShl, ref_shl, &A(3) << &A(2), A(12), 1);
    test_binary!(test_shr, Shr, shr, RefShr, ref_shr, &A(12) >> &A(2), A(3), 1);
    test_binary!(test_bitand, BitAnd, bitand, RefBitAnd, ref_bitand, &A(6) & &A(5), A(4), 1);
    test_binary!(test_bitor, BitOr, bitor, RefBitOr, ref_bitor, &A(3) | &A(5), A(7), 1);
    test_binary!(test_bitxor, BitXor, bitxor, RefBitXor, ref_bitxor, &A(3) ^ &A(5), A(6), 1);
}