1use std::ptr::addr_of;
2use crate::collapse_contiguous::collapse_to_uniform_stride;
3use crate::flat_index_generator::FlatIndexGenerator;
4use paste::paste;
5use std::ops::{BitAnd, BitOr, Rem, Shl, Shr};
6
7
/// Generates a crate-private trait of stride-specialized, element-wise
/// binary operations over raw pointers.
///
/// * `$trait_name`     — name of the generated trait.
/// * `$required_trait` — the `std::ops` trait the element type must implement
///                       (e.g. `Rem`, `BitAnd`); bound as `$required_trait<Output = Self>`.
/// * `$name`           — snake-case stem used for every generated method name
///                       (via `paste!` concatenation).
/// * `$operator`       — the operator token applied element-wise (e.g. `%`, `&`).
/// * `$($default_dtypes),*` — optional list of types forwarded to
///                       `impl_default_trait_for_dtypes!` (defined elsewhere)
///                       to get the default impl of the trait.
///
/// Method-name suffix convention used throughout:
/// * `_stride_X_Y` — both operands use a *uniform* stride; `X`/`Y` encode the
///   lhs/rhs stride: `0` = broadcast a single element, `1` = contiguous,
///   `n` = arbitrary uniform stride passed as a parameter.
/// * `_nonunif_*` / `*_nonunif` — the named side has *non-uniform* strides and
///   is walked with a `FlatIndexGenerator` over its shape/stride arrays.
///
/// Safety (applies to every generated `unsafe fn`): the caller must ensure
/// all pointers are valid for the implied number of reads/writes — these
/// functions dereference `lhs`/`rhs` and write `count` elements through `dst`
/// with no bounds checking.
#[macro_export]
macro_rules! define_binary_op_trait {
    // Entry with a dtype list: define the trait, then stamp out default impls.
    ($trait_name:ident, $required_trait:ident, $name:ident, $operator:tt; $($default_dtypes:ty),*) => {
        define_binary_op_trait!($trait_name, $required_trait, $name, $operator);
        impl_default_trait_for_dtypes!($trait_name, $($default_dtypes),*);
    };

    ($trait_name:ident, $required_trait:ident, $name:ident, $operator:tt) => {
        paste! {
            pub(crate) trait $trait_name: $required_trait<Output=Self> + Sized + Copy {
                // --- Uniform-stride wrappers: all delegate to `_stride_n_n`
                // with the strides baked in. Overriding impls can replace
                // individual specializations (e.g. with SIMD) without
                // touching the generic loop.

                // lhs broadcast (stride 0), rhs contiguous.
                unsafe fn [<$name _stride_0_1>](lhs: *const Self,
                                                rhs: *const Self,
                                                dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, 0, rhs, 1, dst, count)
                }

                // lhs contiguous, rhs broadcast.
                unsafe fn [<$name _stride_1_0>](lhs: *const Self,
                                                rhs: *const Self, dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, 1, rhs, 0, dst, count)
                }

                // lhs arbitrary uniform stride, rhs broadcast.
                unsafe fn [<$name _stride_n_0>](lhs: *const Self, lhs_stride: usize,
                                                rhs: *const Self, dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, lhs_stride, rhs, 0, dst, count)
                }

                // lhs broadcast, rhs arbitrary uniform stride.
                unsafe fn [<$name _stride_0_n>](lhs: *const Self,
                                                rhs: *const Self, rhs_stride: usize,
                                                dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, 0, rhs, rhs_stride, dst, count)
                }

                // Both sides contiguous — the common fast path.
                unsafe fn [<$name _stride_1_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, 1, rhs, 1, dst, count)
                }

                // lhs arbitrary uniform stride, rhs contiguous.
                unsafe fn [<$name _stride_n_1>](lhs: *const Self, lhs_stride: usize,
                                                rhs: *const Self, dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, lhs_stride, rhs, 1, dst, count)
                }

                // lhs contiguous, rhs arbitrary uniform stride.
                unsafe fn [<$name _stride_1_n>](lhs: *const Self,
                                                rhs: *const Self, rhs_stride: usize,
                                                dst: *mut Self, count: usize) {
                    Self::[<$name _stride_n_n>](lhs, 1, rhs, rhs_stride, dst, count)
                }

                // Generic uniform-stride kernel: dst[i] = lhs[i*ls] OP rhs[i*rs].
                // `dst` is always written contiguously. `#[inline(never)]`
                // keeps this as one shared loop body rather than letting it
                // inline into every thin wrapper above.
                #[inline(never)]
                unsafe fn [<$name _stride_n_n>](mut lhs: *const Self, lhs_stride: usize,
                                                mut rhs: *const Self, rhs_stride: usize,
                                                mut dst: *mut Self, mut count: usize) {
                    while count != 0 {
                        *dst = *lhs $operator *rhs;

                        count -= 1;
                        lhs = lhs.add(lhs_stride);
                        rhs = rhs.add(rhs_stride);
                        dst = dst.add(1);
                    }
                }

                // --- Mixed wrappers: one side non-uniform (shape/stride
                // arrays), the other side with a baked-in uniform stride.

                // lhs non-uniform, rhs broadcast (stride 0).
                unsafe fn [<$name _nonunif_0>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
                                               rhs: *const Self,
                                               dst: *mut Self, count: usize) {
                    Self::[<$name _nonunif_n>](lhs, lhs_shape, lhs_stride, rhs, 0, dst, count)
                }

                // lhs broadcast, rhs non-uniform.
                unsafe fn [<$name _0_nonunif>](lhs: *const Self,
                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
                                               dst: *mut Self, count: usize) {
                    Self::[<$name _n_nonunif>](lhs, 0, rhs, rhs_shape, rhs_stride, dst, count)
                }

                // lhs non-uniform, rhs contiguous.
                unsafe fn [<$name _nonunif_1>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
                                               rhs: *const Self,
                                               dst: *mut Self, count: usize) {
                    Self::[<$name _nonunif_n>](lhs, lhs_shape, lhs_stride, rhs, 1, dst, count)
                }

                // lhs contiguous, rhs non-uniform.
                unsafe fn [<$name _1_nonunif>](lhs: *const Self,
                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
                                               dst: *mut Self, count: usize) {
                    Self::[<$name _n_nonunif>](lhs, 1, rhs, rhs_shape, rhs_stride, dst, count)
                }

                // Non-uniform lhs walked via FlatIndexGenerator; uniform rhs.
                // NOTE(review): `unwrap_unchecked` assumes the generator
                // yields at least `count` indices — i.e. `count` must not
                // exceed the product of `lhs_shape`. Caller contract.
                unsafe fn [<$name _nonunif_n>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
                                               mut rhs: *const Self, rhs_stride: usize,
                                               mut dst: *mut Self, mut count: usize) {
                    let mut lhs_indices = FlatIndexGenerator::from(lhs_shape, lhs_stride);

                    while count != 0 {
                        let lhs_index = lhs_indices.next().unwrap_unchecked();
                        *dst = *lhs.add(lhs_index) $operator *rhs;

                        count -= 1;
                        dst = dst.add(1);
                        rhs = rhs.add(rhs_stride);
                    }
                }

                // Mirror of `_nonunif_n`: uniform lhs, non-uniform rhs.
                unsafe fn [<$name _n_nonunif>](mut lhs: *const Self, lhs_stride: usize,
                                               rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
                                               mut dst: *mut Self, mut count: usize) {
                    let mut rhs_indices = FlatIndexGenerator::from(rhs_shape, rhs_stride);

                    while count != 0 {
                        let rhs_index = rhs_indices.next().unwrap_unchecked();
                        *dst = *lhs $operator *rhs.add(rhs_index);

                        count -= 1;
                        dst = dst.add(1);
                        lhs = lhs.add(lhs_stride);
                    }
                }

                // Fully general fallback: both sides non-uniform; element
                // count is implied by the zip of the two index generators.
                unsafe fn [<$name _unspecialized>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
                                                   rhs: *const Self, rhs_shape: &[usize], rhs_stride: &[usize],
                                                   mut dst: *mut Self) {
                    let lhs_indices = FlatIndexGenerator::from(lhs_shape, lhs_stride);
                    let rhs_indices = FlatIndexGenerator::from(rhs_shape, rhs_stride);

                    for (lhs_index, rhs_index) in lhs_indices.zip(rhs_indices) {
                        *dst = *lhs.add(lhs_index) $operator *rhs.add(rhs_index);
                        dst = dst.add(1);
                    }
                }

                // Tensor OP scalar: `rhs` is a by-value element, treated as a
                // stride-0 operand by taking its address on the stack.
                unsafe fn [<$name _scalar>](lhs: *const Self, lhs_shape: &[usize], lhs_stride: &[usize],
                                            rhs: Self, dst: *mut Self) {
                    // Empty stride list denotes a 0-dim (scalar) lhs.
                    if lhs_stride.is_empty() {
                        *dst = *lhs $operator rhs;
                        return;
                    }

                    // `rhs` lives on this stack frame for the duration of the
                    // call, so the pointer stays valid for every delegate below.
                    let rhs = addr_of!(rhs);

                    // Merge contiguous dims so 1-D fast paths apply as often
                    // as possible.
                    let (lhs_shape, lhs_stride) = collapse_to_uniform_stride(lhs_shape, &lhs_stride);
                    let lhs_dims = lhs_shape.len();
                    let lhs_inner_stride = lhs_stride[lhs_dims - 1];

                    if lhs_dims == 1 {
                        if lhs_inner_stride == 1 {
                            return Self::[<$name _stride_1_0>](lhs, rhs, dst, lhs_shape[0]);
                        }
                        else {
                            return Self::[<$name _stride_n_0>](lhs, lhs_inner_stride, rhs, dst, lhs_shape[0]);
                        }
                    }

                    let count = lhs_shape.iter().product();
                    return Self::[<$name _nonunif_0>](lhs, &lhs_shape, &lhs_stride, rhs, dst, count);
                }

                // Top-level dispatcher: both operands share `shape`; each has
                // its own stride array (stride 0 entries express broadcast).
                // Picks the most specialized kernel that fits after collapse.
                unsafe fn $name(lhs: *const Self, lhs_stride: &[usize],
                                rhs: *const Self, rhs_stride: &[usize],
                                dst: *mut Self, shape: &[usize]) {
                    // Both 0-dim: single scalar-scalar element.
                    if lhs_stride.is_empty() && rhs_stride.is_empty() {
                        *dst = *lhs $operator *rhs;
                        return;
                    }

                    // Collapse each side independently; the collapsed shapes
                    // may differ per side but cover the same total count.
                    let (lhs_shape, lhs_stride) = collapse_to_uniform_stride(shape, &lhs_stride);
                    let (rhs_shape, rhs_stride) = collapse_to_uniform_stride(shape, &rhs_stride);

                    let lhs_dims = lhs_shape.len();
                    let rhs_dims = rhs_shape.len();

                    let lhs_inner_stride = lhs_stride[lhs_dims - 1];
                    let rhs_inner_stride = rhs_stride[rhs_dims - 1];

                    // Both sides collapse to a single uniform stride:
                    // dispatch to the flat `_stride_X_Y` kernels.
                    if lhs_dims == 1 && rhs_dims == 1 {
                        if rhs_inner_stride == 0 {
                            if lhs_inner_stride == 1 {
                                return Self::[<$name _stride_1_0>](lhs, rhs, dst, lhs_shape[0]);
                            }
                            else {
                                return Self::[<$name _stride_n_0>](lhs, lhs_inner_stride, rhs, dst, lhs_shape[0]);
                            }

                        } else if lhs_inner_stride == 0 {
                            if rhs_inner_stride == 1 {
                                return Self::[<$name _stride_0_1>](lhs, rhs, dst, rhs_shape[0]);
                            }
                            else {
                                return Self::[<$name _stride_0_n>](lhs, rhs, rhs_inner_stride, dst, rhs_shape[0]);
                            }
                        }

                        if lhs_inner_stride == 1 && rhs_inner_stride == 1 {
                            return Self::[<$name _stride_1_1>](lhs, rhs, dst, lhs_shape[0]);
                        }

                        if lhs_inner_stride == 1 {
                            return Self::[<$name _stride_1_n>](lhs, rhs, rhs_inner_stride, dst, lhs_shape[0]);
                        }
                        else if rhs_inner_stride == 1 {
                            return Self::[<$name _stride_n_1>](lhs, lhs_inner_stride, rhs, dst, rhs_shape[0]);
                        }

                        return Self::[<$name _stride_n_n>](lhs, lhs_inner_stride, rhs, rhs_inner_stride, dst, lhs_shape[0]);
                    }

                    // One side is non-uniform; the other collapsed to 1-D.
                    // The 1-D side's shape[0] is the total element count.
                    if rhs_dims == 1 && rhs_inner_stride == 0 {
                        return Self::[<$name _nonunif_0>](lhs, &lhs_shape, &lhs_stride,
                                                          rhs, dst, rhs_shape[0]);
                    } else if lhs_dims == 1 && lhs_inner_stride == 0 {
                        return Self::[<$name _0_nonunif>](lhs,
                                                          rhs, &rhs_shape, &rhs_stride,
                                                          dst, lhs_shape[0]);
                    }

                    if rhs_dims == 1 && rhs_inner_stride == 1 {
                        return Self::[<$name _nonunif_1>](lhs, &lhs_shape, &lhs_stride,
                                                          rhs, dst, rhs_shape[0]);
                    } else if lhs_dims == 1 && lhs_inner_stride == 1 {
                        return Self::[<$name _1_nonunif>](lhs,
                                                          rhs, &rhs_shape, &rhs_stride,
                                                          dst, lhs_shape[0]);
                    }

                    if rhs_dims == 1 {
                        return Self::[<$name _nonunif_n>](lhs, &lhs_shape, &lhs_stride,
                                                          rhs, rhs_inner_stride,
                                                          dst, rhs_shape[0]);
                    } else if lhs_dims == 1 {
                        return Self::[<$name _n_nonunif>](lhs, lhs_inner_stride,
                                                          rhs, &rhs_shape, &rhs_stride,
                                                          dst, lhs_shape[0]);
                    }

                    // Neither side collapsed: fully general path.
                    Self::[<$name _unspecialized>](lhs, &lhs_shape, &lhs_stride,
                                                   rhs, &rhs_shape, &rhs_stride,
                                                   dst);
                }
            }
        }
    }
}
254
// Instantiate the binary-op traits for the primitive dtypes.
// `%` includes the float types because `Rem` is implemented for f32/f64;
// the bitwise and shift operators are defined on integer types only.
define_binary_op_trait!(BinaryOpRem, Rem, rem, %;
                        i8, i16, i32, i64, i128, isize,
                        u8, u16, u32, u64, u128, usize,
                        f32, f64);

define_binary_op_trait!(BinaryOpBitAnd, BitAnd, bitand, &;
                        i8, i16, i32, i64, i128, isize,
                        u8, u16, u32, u64, u128, usize);

define_binary_op_trait!(BinaryOpBitOr, BitOr, bitor, |;
                        i8, i16, i32, i64, i128, isize,
                        u8, u16, u32, u64, u128, usize);

define_binary_op_trait!(BinaryOpShl, Shl, shl, <<;
                        i8, i16, i32, i64, i128, isize,
                        u8, u16, u32, u64, u128, usize);

define_binary_op_trait!(BinaryOpShr, Shr, shr, >>;
                        i8, i16, i32, i64, i128, isize,
                        u8, u16, u32, u64, u128, usize);