ref_ops/
ref_mut_binary.rs

1use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub};
2
/// Attaches each leading string expression as a `#[doc]` attribute to the
/// single item that follows the `@` marker.
///
/// Syntax: zero or more `expr,` entries (every entry, the last one included,
/// is followed by a comma), then `@`, then one item. Taking `expr` rather than
/// `literal` lets callers pass `concat!`/`stringify!` output as doc text.
macro_rules! doc {
    ($( $line:expr, )* @$target:item) => {
        $(
            #[doc = $line]
        )*
        $target
    };
}
9
/// Defines the `RefMut*` trait for one binary operator.
///
/// Parameters:
/// * `$Op` / `$op` — the `core::ops` trait and its method (e.g. `Add`, `add`).
/// * `$RefOp` / `$ref_op` — the generated trait and its method
///   (e.g. `RefMutAdd`, `ref_mut_add`).
///
/// The expansion contains:
/// * a private module named `$op` holding a `Sealed` supertrait, so the
///   generated trait cannot be implemented from outside this crate;
/// * the public trait `$RefOp<Rhs = Self>`, documented through `doc!`;
/// * blanket impls of `Sealed` and `$RefOp` for every `T: ?Sized` for which
///   `&'a mut T: $Op<Rhs>` holds for all lifetimes `'a` with one
///   lifetime-independent `Output` type `O`.
///
/// `doc!` must be defined before the point where this macro is invoked.
macro_rules! def_binary {
    ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
        mod $op {
            pub trait Sealed<Rhs = Self> {}
        }

        doc!(
            concat!("`", stringify!($op), "` operation through mutable references."),
            "",
            "As of Rust 1.73.0, the following code does not compile:",
            "```compile_fail",
            concat!("use core::ops::", stringify!($Op), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, 'b, T, U> ", stringify!($Op), "<&'b mut A<U>> for &'a mut A<T>"),
            "where",
            concat!("    &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    type Output = A<<&'a mut T as ", stringify!($Op), "<&'b mut U>>::Output>;"),
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'b mut A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($op), "(&mut rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(mut a: T, mut b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            "but the following code does:",
            "```",
            concat!("use core::ops::", stringify!($Op), ";"),
            concat!("use ref_ops::", stringify!($RefOp), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, T, U> ", stringify!($Op), "<&'a mut A<U>> for &mut A<T>"),
            "where",
            concat!("    T: ", stringify!($RefOp), "<&'a mut U>,"),
            "{",
            "    type Output = A<T::Output>;",
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'a mut A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($ref_op), "(&mut rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(mut a: T, mut b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            concat!("    let _a_op_b = (&mut a).", stringify!($op), "(&mut b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a mut T: ", stringify!($Op), "<&'b mut U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            @pub trait $RefOp<Rhs = Self>: $op::Sealed<Rhs> {
                doc!(
                    concat!("The resulting type after applying `", stringify!($op), "` operation."),
                    @type Output;
                );

                doc!(
                    concat!("Performs `", stringify!($op), "` operation."),
                    @fn $ref_op(&mut self, rhs: Rhs) -> Self::Output;
                );
            }
        );

        // Every type whose `&mut` borrow supports the operator (for any borrow
        // lifetime, with a single `Output`) is `Sealed`; the private module
        // keeps outside crates from adding their own impls.
        impl<T, Rhs, O> $op::Sealed<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a mut T: $Op<Rhs, Output = O>,
        {
        }

        impl<T, Rhs, O> $RefOp<Rhs> for T
        where
            T: ?Sized,
            // `O` is bound outside the `for<'a>` so `Output` cannot depend on
            // the lifetime of the `&mut self` borrow.
            for<'a> &'a mut T: $Op<Rhs, Output = O>,
        {
            type Output = O;

            fn $ref_op(&mut self, rhs: Rhs) -> O {
                // Fully-qualified call: dispatch explicitly to the `&mut T` impl
                // named in the `where` clause instead of relying on autoref
                // method-call resolution.
                <&mut T as $Op<Rhs>>::$op(self, rhs)
            }
        }
    };
}
119
120def_binary!(Add, add, RefMutAdd, ref_mut_add);
121def_binary!(Sub, sub, RefMutSub, ref_mut_sub);
122def_binary!(Mul, mul, RefMutMul, ref_mut_mul);
123def_binary!(Div, div, RefMutDiv, ref_mut_div);
124def_binary!(Rem, rem, RefMutRem, ref_mut_rem);
125def_binary!(Shl, shl, RefMutShl, ref_mut_shl);
126def_binary!(Shr, shr, RefMutShr, ref_mut_shr);
127def_binary!(BitAnd, bitand, RefMutBitAnd, ref_mut_bitand);
128def_binary!(BitOr, bitor, RefMutBitOr, ref_mut_bitor);
129def_binary!(BitXor, bitxor, RefMutBitXor, ref_mut_bitxor);
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        RefAdd, RefBitAnd, RefBitOr, RefBitXor, RefDiv, RefMul, RefRem, RefShl, RefShr, RefSub,
    };

    /// Inner value type whose `&mut B<T> <op> &mut B<U>` impls forward to the
    /// shared-reference `Ref*` traits of `T`.
    ///
    /// `Debug` is derived so `assert_eq!` can print both sides when a test fails.
    #[derive(Debug, PartialEq)]
    struct B<T>(T);

    /// Implements `$Op<&mut B<U>> for &mut B<T>` by calling `T::$ref_op(&U)`.
    macro_rules! impl_binary {
        ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
            impl<'a, T, U> $Op<&'a mut B<U>> for &mut B<T>
            where
                T: $RefOp<&'a U>,
            {
                type Output = B<T::Output>;

                fn $op(self, rhs: &'a mut B<U>) -> Self::Output {
                    B(self.0.$ref_op(&rhs.0))
                }
            }
        };
    }

    impl_binary!(Add, add, RefAdd, ref_add);
    impl_binary!(Sub, sub, RefSub, ref_sub);
    impl_binary!(Mul, mul, RefMul, ref_mul);
    impl_binary!(Div, div, RefDiv, ref_div);
    impl_binary!(Rem, rem, RefRem, ref_rem);
    impl_binary!(Shl, shl, RefShl, ref_shl);
    impl_binary!(Shr, shr, RefShr, ref_shr);
    impl_binary!(BitAnd, bitand, RefBitAnd, ref_bitand);
    impl_binary!(BitOr, bitor, RefBitOr, ref_bitor);
    impl_binary!(BitXor, bitxor, RefBitXor, ref_bitxor);

    /// Generates one `#[test]` for a `RefMut*` trait.
    ///
    /// * `$fn` — name of the generated test function.
    /// * `$Op`/`$op` — the `core::ops` trait and method implemented for `&mut A<T>`.
    /// * `$RefOp`/`$ref_op` — the `RefMut*` trait and method under test.
    /// * `$actual` — an operator expression on `&mut A(..)` operands.
    /// * `$expected` — the value `$actual` must equal.
    /// * `$dummy` — argument passed (twice) to `g`. This mirrors the crate-level
    ///   doc example: with `A`'s impl in scope, the higher-ranked `&mut` operator
    ///   bounds on `f`/`g` must still type-check.
    macro_rules! test_binary {
        (
            $fn:ident,
            $Op:ident,
            $op:ident,
            $RefOp:ident,
            $ref_op:ident,
            $actual:expr,
            $expected:expr,
            $dummy:expr
        ) => {
            #[test]
            fn $fn() {
                #[derive(Debug, PartialEq)]
                struct A<T>(T);

                impl<'a, T, U> $Op<&'a mut A<U>> for &mut A<T>
                where
                    T: $RefOp<&'a mut U>,
                {
                    type Output = A<T::Output>;

                    fn $op(self, rhs: &'a mut A<U>) -> Self::Output {
                        A(self.0.$ref_op(&mut rhs.0))
                    }
                }

                fn f<T, U>(mut a: T, mut b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    let _a_op_b = (&mut a).$op(&mut b);

                    // to do something with `a`, `b`, and `_a_op_b`
                }

                fn g<T, U>(a: T, b: U)
                where
                    for<'a, 'b> &'a mut T: $Op<&'b mut U>,
                {
                    f(a, b);
                }

                g($dummy, $dummy);

                assert_eq!($actual, $expected);
            }
        };
    }

    test_binary!(
        test_add,
        Add,
        add,
        RefMutAdd,
        ref_mut_add,
        &mut A(B(1.0)) + &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_sub,
        Sub,
        sub,
        RefMutSub,
        ref_mut_sub,
        &mut A(B(3.0)) - &mut A(B(1.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_mul,
        Mul,
        mul,
        RefMutMul,
        ref_mut_mul,
        &mut A(B(2.0)) * &mut A(B(3.0)),
        A(B(6.0)),
        B(1.0)
    );
    test_binary!(
        test_div,
        Div,
        div,
        RefMutDiv,
        ref_mut_div,
        &mut A(B(6.0)) / &mut A(B(2.0)),
        A(B(3.0)),
        B(1.0)
    );
    test_binary!(
        test_rem,
        Rem,
        rem,
        RefMutRem,
        ref_mut_rem,
        &mut A(B(6.0)) % &mut A(B(4.0)),
        A(B(2.0)),
        B(1.0)
    );
    test_binary!(
        test_shl,
        Shl,
        shl,
        RefMutShl,
        ref_mut_shl,
        &mut A(B(3)) << &mut A(B(2)),
        A(B(12)),
        B(1)
    );
    test_binary!(
        test_shr,
        Shr,
        shr,
        RefMutShr,
        ref_mut_shr,
        &mut A(B(12)) >> &mut A(B(2)),
        A(B(3)),
        B(1)
    );
    test_binary!(
        test_bitand,
        BitAnd,
        bitand,
        RefMutBitAnd,
        ref_mut_bitand,
        &mut A(B(6)) & &mut A(B(5)),
        A(B(4)),
        B(1)
    );
    test_binary!(
        test_bitor,
        BitOr,
        bitor,
        RefMutBitOr,
        ref_mut_bitor,
        &mut A(B(3)) | &mut A(B(5)),
        A(B(7)),
        B(1)
    );
    test_binary!(
        test_bitxor,
        BitXor,
        bitxor,
        RefMutBitXor,
        ref_mut_bitxor,
        &mut A(B(3)) ^ &mut A(B(5)),
        A(B(6)),
        B(1)
    );
}
298}