ref_ops/
ref_binary.rs

1use core::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub};
2
/// Attaches every leading expression to the trailing item as a `#[doc = ...]` attribute,
/// keeping their order.
///
/// Syntax: `doc!(line, line, ..., @item)`. Each line expression is followed by a comma and
/// the item is introduced with `@`. Line expressions may themselves be macro calls such as
/// `concat!(...)`, since they are matched as `expr` fragments.
macro_rules! doc {
    ( $( $line:expr , )* @ $target:item ) => {
        $(
            #[doc = $line]
        )*
        $target
    };
}
9
/// Defines, for one binary operator of `core::ops`, a sealed `Ref*` trait together with the
/// blanket impls that implement it for every type whose shared references support the
/// operator.
///
/// Arguments, in order:
/// - `$Op`: the operator trait (e.g. `Add`); it must be in scope where this macro is invoked;
/// - `$op`: that trait's method (e.g. `add`); also used as the name of the private module
///   that holds the sealing trait;
/// - `$RefOp`: the name of the trait to define (e.g. `RefAdd`);
/// - `$ref_op`: the name of that trait's method (e.g. `ref_add`).
///
/// The expansion calls `doc!`, which must also be in scope at the invocation site.
macro_rules! def_binary {
    ($Op:ident, $op:ident, $RefOp:ident, $ref_op:ident) => {
        // Sealed-trait pattern: `mod $op` is private, so `Sealed` cannot be named outside the
        // module invoking this macro, while still being usable as a supertrait of the public
        // `$RefOp`.
        mod $op {
            pub trait Sealed<Rhs = Self> {}
        }

        doc!(
            concat!("`", stringify!($op), "` operation through references."),
            "",
            concat!(
                "This trait is implemented for every `T: ?Sized` for which `for<'a> &'a T: ",
                stringify!($Op),
                "<Rhs>` holds with an `Output` type that does not depend on `'a`. ",
                "It is sealed, so it cannot be implemented outside this crate."
            ),
            "",
            "As of Rust 1.73.0, the following code does not compile:",
            "```compile_fail",
            concat!("use core::ops::", stringify!($Op), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, 'b, T, U> ", stringify!($Op), "<&'b A<U>> for &'a A<T>"),
            "where",
            concat!("    &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            concat!("    type Output = A<<&'a T as ", stringify!($Op), "<&'b U>>::Output>;"),
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'b A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($op), "(&rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            concat!("    let _a_op_b = (&a).", stringify!($op), "(&b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            "but the following code does:",
            "```",
            concat!("use core::ops::", stringify!($Op), ";"),
            concat!("use ref_ops::", stringify!($RefOp), ";"),
            "",
            "struct A<T>(T);",
            "",
            concat!("impl<'a, T, U> ", stringify!($Op), "<&'a A<U>> for &A<T>"),
            "where",
            concat!("    T: ", stringify!($RefOp), "<&'a U>,"),
            "{",
            "    type Output = A<T::Output>;",
            "",
            concat!("    fn ", stringify!($op), "(self, rhs: &'a A<U>) -> Self::Output {"),
            concat!("        A(self.0.", stringify!($ref_op), "(&rhs.0))"),
            "    }",
            "}",
            "",
            "fn _f<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            concat!("    let _a_op_b = (&a).", stringify!($op), "(&b);"),
            "",
            "    // to do something with `a`, `b`, and `_a_op_b`",
            "}",
            "",
            "fn _g<T, U>(a: T, b: U)",
            "where",
            concat!("    for<'a, 'b> &'a T: ", stringify!($Op), "<&'b U>,"),
            "{",
            "    _f(a, b);",
            "}",
            "```",
            @pub trait $RefOp<Rhs = Self>: $op::Sealed<Rhs> {
                doc!(
                    concat!("The resulting type after applying `", stringify!($op), "` operation."),
                    @type Output;
                );

                doc!(
                    concat!("Performs `", stringify!($op), "` operation."),
                    @fn $ref_op(&self, rhs: Rhs) -> Self::Output;
                );
            }
        );

        // Implemented under exactly the same bounds as the blanket `$RefOp` impl below, so
        // every type that receives `$RefOp` also satisfies its `Sealed` supertrait.
        impl<T, Rhs, O> $op::Sealed<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a T: $Op<Rhs, Output = O>,
        {
        }

        // `O` is a single impl-level type parameter, so `&'a T op Rhs` must produce the same
        // `Output` type for every lifetime `'a`; that type becomes `$RefOp::Output`.
        impl<T, Rhs, O> $RefOp<Rhs> for T
        where
            T: ?Sized,
            for<'a> &'a T: $Op<Rhs, Output = O>,
        {
            type Output = O;

            // `self` already has type `&T`, so this forwards to the `$Op` impl for `&T`.
            fn $ref_op(&self, rhs: Rhs) -> O {
                self.$op(rhs)
            }
        }
    };
}
119
120def_binary!(Add, add, RefAdd, ref_add);
121def_binary!(Sub, sub, RefSub, ref_sub);
122def_binary!(Mul, mul, RefMul, ref_mul);
123def_binary!(Div, div, RefDiv, ref_div);
124def_binary!(Rem, rem, RefRem, ref_rem);
125def_binary!(Shl, shl, RefShl, ref_shl);
126def_binary!(Shr, shr, RefShr, ref_shr);
127def_binary!(BitAnd, bitand, RefBitAnd, ref_bitand);
128def_binary!(BitOr, bitor, RefBitOr, ref_bitor);
129def_binary!(BitXor, bitxor, RefBitXor, ref_bitxor);
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands each `name => Op::op, RefOp::ref_op, [probe] check;` entry into one `#[test]`.
    ///
    /// Every generated test:
    /// - defines a local newtype `A<T: ?Sized>(T)` and implements `Op` for `&A<T>` (with
    ///   `&A<U>` on the right) by delegating to `RefOp` on the wrapped values;
    /// - forwards a higher-ranked `for<'a, 'b> &'a T: Op<&'b U>` bound from one generic
    ///   function to another, instantiated with `probe` for both operands;
    /// - finally asserts `check`.
    macro_rules! test_binary {
        ($(
            $name:ident => $Op:ident::$op:ident, $RefOp:ident::$ref_op:ident,
                [$probe:literal] $check:expr;
        )*) => {$(
            #[test]
            fn $name() {
                #[derive(PartialEq)]
                struct A<T: ?Sized>(T);

                impl<'a, T, U> $Op<&'a A<U>> for &A<T>
                where
                    T: ?Sized + $RefOp<&'a U>,
                {
                    type Output = A<<T as $RefOp<&'a U>>::Output>;

                    fn $op(self, rhs: &'a A<U>) -> Self::Output {
                        A(<T as $RefOp<&'a U>>::$ref_op(&self.0, &rhs.0))
                    }
                }

                // Requires the operator to work for every pair of borrow lifetimes.
                fn use_pair<T, U>(lhs: T, rhs: U)
                where
                    for<'a, 'b> &'a T: $Op<&'b U>,
                {
                    let _combined = (&lhs).$op(&rhs);

                    // to do something with `lhs`, `rhs`, and `_combined`
                }

                // Must re-prove the same higher-ranked bound from its own where clause.
                fn forward_pair<T, U>(lhs: T, rhs: U)
                where
                    for<'a, 'b> &'a T: $Op<&'b U>,
                {
                    use_pair(lhs, rhs);
                }

                forward_pair($probe, $probe);

                assert!($check);
            }
        )*};
    }

    test_binary! {
        test_add => Add::add, RefAdd::ref_add, [1.0] &A(1.0) + &A(2.0) == A(3.0);
        test_sub => Sub::sub, RefSub::ref_sub, [1.0] &A(3.0) - &A(1.0) == A(2.0);
        test_mul => Mul::mul, RefMul::ref_mul, [1.0] &A(2.0) * &A(3.0) == A(6.0);
        test_div => Div::div, RefDiv::ref_div, [1.0] &A(6.0) / &A(2.0) == A(3.0);
        test_rem => Rem::rem, RefRem::ref_rem, [1.0] &A(6.0) % &A(4.0) == A(2.0);
        test_shl => Shl::shl, RefShl::ref_shl, [1] &A(3) << &A(2) == A(12);
        test_shr => Shr::shr, RefShr::ref_shr, [1] &A(12) >> &A(2) == A(3);
        test_bitand => BitAnd::bitand, RefBitAnd::ref_bitand, [1] &A(6) & &A(5) == A(4);
        test_bitor => BitOr::bitor, RefBitOr::ref_bitor, [1] &A(3) | &A(5) == A(7);
        test_bitxor => BitXor::bitxor, RefBitXor::ref_bitxor, [1] &A(3) ^ &A(5) == A(6);
    }
}