// redstone_ml/ops/binary_ops.rs

1use std::ptr::addr_of;
2use crate::collapse_contiguous::collapse_to_uniform_stride;
3use crate::flat_index_generator::FlatIndexGenerator;
4use paste::paste;
5use std::ops::{BitAnd, BitOr, Rem, Shl, Shr};
6
7
8#[macro_export]
9macro_rules! define_binary_op_trait {
10    ($trait_name:ident, $required_trait:ident, $name:ident, $operator:tt; $($default_dtypes:ty),*) => {
11        define_binary_op_trait!($trait_name, $required_trait, $name, $operator);
12        impl_default_trait_for_dtypes!($trait_name, $($default_dtypes),*);
13    };
14
15    ($trait_name:ident, $required_trait:ident, $name:ident, $operator:tt) => {
16        paste! {
17            pub(crate) trait $trait_name: $required_trait<Output=Self> + Sized + Copy {
18                unsafe fn [<$name _stride_0_1>](lhs: *const Self,
19                                                rhs: *const Self,
20                                                dst: *mut Self, count: usize) {
21                    Self::[<$name _stride_n_n>](lhs, 0, rhs, 1, dst, count)
22                }
23
24                unsafe fn [<$name _stride_1_0>](lhs: *const Self,
25                                                rhs: *const Self, dst: *mut Self, count: usize) {
26                    Self::[<$name _stride_n_n>](lhs, 1, rhs, 0, dst, count)
27                }
28
29                unsafe fn [<$name _stride_n_0>](lhs: *const Self, lhs_stride: usize,
30                                                rhs: *const Self, dst: *mut Self, count: usize) {
31                    Self::[<$name _stride_n_n>](lhs, lhs_stride, rhs, 0, dst, count)
32                }
33
34                unsafe fn [<$name _stride_0_n>](lhs: *const Self,
35                                                rhs: *const Self, rhs_stride: usize,
36                                                dst: *mut Self, count: usize) {
37                    Self::[<$name _stride_n_n>](lhs, 0, rhs, rhs_stride, dst, count)
38                }
39
40                unsafe fn [<$name _stride_1_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
41                    Self::[<$name _stride_n_n>](lhs, 1, rhs, 1, dst, count)
42                }
43
44                unsafe fn [<$name _stride_n_1>](lhs: *const Self, lhs_stride: usize,
45                                                rhs: *const Self, dst: *mut Self, count: usize) {
46                    Self::[<$name _stride_n_n>](lhs, lhs_stride, rhs, 1, dst, count)
47                }
48
49                unsafe fn [<$name _stride_1_n>](lhs: *const Self,
50                                                rhs: *const Self, rhs_stride: usize,
51                                                dst: *mut Self, count: usize) {
52                    Self::[<$name _stride_n_n>](lhs, 1, rhs, rhs_stride, dst, count)
53                }
54
55                #[inline(never)]
56                unsafe fn [<$name _stride_n_n>](mut lhs: *const Self, lhs_stride: usize,
57                                                mut rhs: *const Self, rhs_stride: usize,
58                                                mut dst: *mut Self, mut count: usize) {
59                    while count != 0 {
60                        *dst = *lhs $operator *rhs;
61
62                        count -= 1;
63                        lhs = lhs.add(lhs_stride);
64                        rhs = rhs.add(rhs_stride);
65                        dst = dst.add(1);
66                    }
67                }
68
69                unsafe fn [<$name _nonunif_0>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
70                                               rhs: *const Self,
71                                               dst: *mut Self, count: usize) {
72                    Self::[<$name _nonunif_n>](lhs, lhs_shape, lhs_stride, rhs, 0, dst, count)
73                }
74
75                unsafe fn [<$name _0_nonunif>](lhs: *const Self,
76                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
77                                               dst: *mut Self, count: usize) {
78                    Self::[<$name _n_nonunif>](lhs, 0, rhs, rhs_shape, rhs_stride, dst, count)
79                }
80
81                unsafe fn [<$name _nonunif_1>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
82                                               rhs: *const Self,
83                                               dst: *mut Self, count: usize) {
84                    Self::[<$name _nonunif_n>](lhs, lhs_shape, lhs_stride, rhs, 1, dst, count)
85                }
86
87                unsafe fn [<$name _1_nonunif>](lhs: *const Self,
88                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
89                                               dst: *mut Self, count: usize) {
90                    Self::[<$name _n_nonunif>](lhs, 1, rhs, rhs_shape, rhs_stride, dst, count)
91                }
92
93                unsafe fn [<$name _nonunif_n>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
94                                               mut rhs: *const Self, rhs_stride: usize,
95                                               mut dst: *mut Self, mut count: usize) {
96                    let mut lhs_indices = FlatIndexGenerator::from(lhs_shape, lhs_stride);
97
98                    while count != 0 {
99                        let lhs_index = lhs_indices.next().unwrap_unchecked();
100                        *dst = *lhs.add(lhs_index) $operator *rhs;
101
102                        count -= 1;
103                        dst = dst.add(1);
104                        rhs = rhs.add(rhs_stride);
105                    }
106                }
107
108                unsafe fn [<$name _n_nonunif>](mut lhs: *const Self, lhs_stride: usize,
109                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
110                                               mut dst: *mut Self, mut count: usize) {
111                    let mut rhs_indices = FlatIndexGenerator::from(rhs_shape, rhs_stride);
112
113                    while count != 0 {
114                        let rhs_index = rhs_indices.next().unwrap_unchecked();
115                        *dst = *lhs $operator *rhs.add(rhs_index);
116
117                        count -= 1;
118                        dst = dst.add(1);
119                        lhs = lhs.add(lhs_stride);
120                    }
121                }
122
123                unsafe fn [<$name _unspecialized>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
124                                                   rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
125                                                   mut dst: *mut Self) {
126                    let lhs_indices = FlatIndexGenerator::from(lhs_shape, lhs_stride);
127                    let rhs_indices = FlatIndexGenerator::from(rhs_shape, rhs_stride);
128
129                    for (lhs_index, rhs_index) in lhs_indices.zip(rhs_indices) {
130                        *dst = *lhs.add(lhs_index) $operator *rhs.add(rhs_index);
131                        dst = dst.add(1);
132                    }
133                }
134
135                unsafe fn [<$name _scalar>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
136                                            rhs: Self, dst: *mut Self) {
137                    // special case for scalar operands
138                    if lhs_stride.is_empty() {
139                        *dst = *lhs $operator rhs;
140                        return;
141                    }
142
143                    let rhs = addr_of!(rhs);
144
145                    let (lhs_shape, lhs_stride) = collapse_to_uniform_stride(lhs_shape, &lhs_stride);
146                    let lhs_dims = lhs_shape.len();
147                    let lhs_inner_stride = lhs_stride[lhs_dims - 1];
148
149                    if lhs_dims == 1 {
150                        if lhs_inner_stride == 1 {
151                            return Self::[<$name _stride_1_0>](lhs, rhs, dst, lhs_shape[0]);
152                        }
153                        else {
154                            return Self::[<$name _stride_n_0>](lhs, lhs_inner_stride, rhs, dst, lhs_shape[0]);
155                        }
156                    }
157
158                    let count = lhs_shape.iter().product();
159                    return Self::[<$name _nonunif_0>](lhs, &lhs_shape, &lhs_stride, rhs, dst, count);
160                }
161
162                unsafe fn $name(lhs: *const Self, lhs_stride: &[usize],
163                                rhs: *const Self, rhs_stride: &[usize],
164                                dst: *mut Self, shape: &[usize]) {
165                    // special case for scalar operands
166                    if lhs_stride.is_empty() && rhs_stride.is_empty() {
167                        *dst = *lhs $operator *rhs;
168                        return;
169                    }
170
171                    let (lhs_shape, lhs_stride) = collapse_to_uniform_stride(shape, &lhs_stride);
172                    let (rhs_shape, rhs_stride) = collapse_to_uniform_stride(shape, &rhs_stride);
173
174                    let lhs_dims = lhs_shape.len();
175                    let rhs_dims = rhs_shape.len();
176
177                    let lhs_inner_stride = lhs_stride[lhs_dims - 1];
178                    let rhs_inner_stride = rhs_stride[rhs_dims - 1];
179
180                    if lhs_dims == 1 && rhs_dims == 1 { // both operands have a uniform stride
181
182                        // one operand is a scalar
183                        if rhs_inner_stride == 0 {
184                            if lhs_inner_stride == 1 {
185                                return Self::[<$name _stride_1_0>](lhs, rhs, dst, lhs_shape[0]);
186                            }
187                            else {
188                                return Self::[<$name _stride_n_0>](lhs, lhs_inner_stride, rhs, dst, lhs_shape[0]);
189                            }
190
191                        } else if lhs_inner_stride == 0 {
192                            if rhs_inner_stride == 1 {
193                                return Self::[<$name _stride_0_1>](lhs, rhs, dst, rhs_shape[0]);
194                            }
195                            else {
196                                return Self::[<$name _stride_0_n>](lhs, rhs, rhs_inner_stride, dst, rhs_shape[0]);
197                            }
198                        }
199
200                        // both operands are contiguous
201                        if lhs_inner_stride == 1 && rhs_inner_stride == 1 {
202                            return Self::[<$name _stride_1_1>](lhs, rhs, dst, lhs_shape[0]);
203                        }
204
205                        if lhs_inner_stride == 1 {
206                            return Self::[<$name _stride_1_n>](lhs, rhs, rhs_inner_stride, dst, lhs_shape[0]);
207                        }
208                        else if rhs_inner_stride == 1 {
209                            return Self::[<$name _stride_n_1>](lhs, lhs_inner_stride, rhs, dst, rhs_shape[0]);
210                        }
211
212                        // neither element is contiguous
213                        return Self::[<$name _stride_n_n>](lhs, lhs_inner_stride, rhs, rhs_inner_stride, dst, lhs_shape[0]);
214                    }
215
216                    // only 1 operand has a uniform stride
217                    if rhs_dims == 1 && rhs_inner_stride == 0 {
218                        return Self::[<$name _nonunif_0>](lhs, &lhs_shape, &lhs_stride,
219                                                          rhs, dst, rhs_shape[0]);
220                    } else if lhs_dims == 1 && lhs_inner_stride == 0 {
221                        return Self::[<$name _0_nonunif>](lhs,
222                                                          rhs, &rhs_shape, &rhs_stride,
223                                                          dst, lhs_shape[0]);
224                    }
225
226                    if rhs_dims == 1 && rhs_inner_stride == 1 {
227                        return Self::[<$name _nonunif_1>](lhs, &lhs_shape, &lhs_stride,
228                                                          rhs, dst, rhs_shape[0]);
229                    } else if lhs_dims == 1 && lhs_inner_stride == 1 {
230                        return Self::[<$name _1_nonunif>](lhs,
231                                                          rhs, &rhs_shape, &rhs_stride,
232                                                          dst, lhs_shape[0]);
233                    }
234
235                    if rhs_dims == 1 {
236                        return Self::[<$name _nonunif_n>](lhs, &lhs_shape, &lhs_stride,
237                                                          rhs, rhs_inner_stride,
238                                                          dst, rhs_shape[0]);
239                    } else if lhs_dims == 1 {
240                        return Self::[<$name _n_nonunif>](lhs, lhs_inner_stride,
241                                                          rhs, &rhs_shape, &rhs_stride,
242                                                          dst, lhs_shape[0]);
243                    }
244
245                    // unspecialized loop
246                    Self::[<$name _unspecialized>](lhs, &lhs_shape, &lhs_stride,
247                                                   rhs, &rhs_shape, &rhs_stride,
248                                                   dst);
249                }
250            }
251        }
252    }
253}
254
255define_binary_op_trait!(BinaryOpRem, Rem, rem, %;
256                        i8, i16, i32, i64, i128, isize,
257                        u8, u16, u32, u64, u128, usize,
258                        f32, f64);
259
260define_binary_op_trait!(BinaryOpBitAnd, BitAnd, bitand, &;
261                        i8, i16, i32, i64, i128, isize,
262                        u8, u16, u32, u64, u128, usize);
263
264define_binary_op_trait!(BinaryOpBitOr, BitOr, bitor, |;
265                        i8, i16, i32, i64, i128, isize,
266                        u8, u16, u32, u64, u128, usize);
267
268define_binary_op_trait!(BinaryOpShl, Shl, shl, <<;
269                        i8, i16, i32, i64, i128, isize,
270                        u8, u16, u32, u64, u128, usize);
271
272define_binary_op_trait!(BinaryOpShr, Shr, shr, >>;
273                        i8, i16, i32, i64, i128, isize,
274                        u8, u16, u32, u64, u128, usize);