1use std::num::Wrapping;
7
8use crate::{U32SimdVec, f16, impl_f32_array_interface};
9
10use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask, U8SimdVec, U16SimdVec};
11
/// Descriptor for the scalar (non-SIMD) fallback backend.
///
/// Every "vector" in this backend holds exactly one lane, so all SIMD
/// operations degrade to plain scalar arithmetic. The descriptor itself is
/// a stateless zero-sized token.
#[derive(Clone, Copy, Debug)]
pub struct ScalarDescriptor;
14
impl ScalarDescriptor {
    /// Builds the descriptor from an `archmage` scalar capability token.
    /// The token is only proof of capability; the scalar backend carries no
    /// state, so it is ignored.
    #[inline]
    pub fn from_token(_token: archmage::ScalarToken) -> Self {
        Self
    }
}
21
22impl SimdDescriptor for ScalarDescriptor {
23 type F32Vec = f32;
24 type I32Vec = Wrapping<i32>;
25 type U32Vec = Wrapping<u32>;
26 type U8Vec = u8;
27 type U16Vec = u16;
28 type Mask = bool;
29 type Bf16Table8 = [f32; 8];
30
31 type Descriptor256 = Self;
32 type Descriptor128 = Self;
33
34 #[inline]
35 fn maybe_downgrade_256bit(self) -> Self::Descriptor256 {
36 self
37 }
38
39 #[inline]
40 fn maybe_downgrade_128bit(self) -> Self::Descriptor128 {
41 self
42 }
43
44 #[inline]
45 fn new() -> Option<Self> {
46 Some(Self)
47 }
48
49 fn call<R>(self, f: impl FnOnce(Self) -> R) -> R {
50 f(self)
52 }
53}
54
/// Scalar implementation of the f32 "vector": a single f32 lane.
impl F32SimdVec for f32 {
    type Descriptor = ScalarDescriptor;

    // One lane per "vector".
    const LEN: usize = 1;

    /// Loads the single lane from the front of `mem` (panics if empty).
    #[inline(always)]
    fn load(_d: Self::Descriptor, mem: &[f32]) -> Self {
        mem[0]
    }

    /// Stores the single lane to the front of `mem`.
    #[inline(always)]
    fn store(&self, mem: &mut [f32]) {
        mem[0] = *self;
    }

    // At lane width 1, "interleaving" degenerates to writing the inputs
    // consecutively to the front of `dest`.
    #[inline(always)]
    fn store_interleaved_2(a: Self, b: Self, dest: &mut [f32]) {
        dest[0] = a;
        dest[1] = b;
    }

    #[inline(always)]
    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [f32]) {
        dest[0] = a;
        dest[1] = b;
        dest[2] = c;
    }

    #[inline(always)]
    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [f32]) {
        dest[0] = a;
        dest[1] = b;
        dest[2] = c;
        dest[3] = d;
    }

    #[inline(always)]
    fn store_interleaved_8(
        a: Self,
        b: Self,
        c: Self,
        d: Self,
        e: Self,
        f: Self,
        g: Self,
        h: Self,
        dest: &mut [f32],
    ) {
        dest[0] = a;
        dest[1] = b;
        dest[2] = c;
        dest[3] = d;
        dest[4] = e;
        dest[5] = f;
        dest[6] = g;
        dest[7] = h;
    }

    // Likewise, "deinterleaving" is just reading consecutive elements.
    #[inline(always)]
    fn load_deinterleaved_2(_d: Self::Descriptor, src: &[f32]) -> (Self, Self) {
        (src[0], src[1])
    }

    #[inline(always)]
    fn load_deinterleaved_3(_d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self) {
        (src[0], src[1], src[2])
    }

    #[inline(always)]
    fn load_deinterleaved_4(_d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self, Self) {
        (src[0], src[1], src[2], src[3])
    }

    /// `self * mul + add`, computed as a separate multiply and add (NOT
    /// fused). NOTE(review): rounding may differ from FMA-based SIMD
    /// backends by one ulp — presumably intentional for the scalar path
    /// (avoids the slow soft-FMA of `f32::mul_add` without hardware FMA);
    /// confirm this tolerance is acceptable to callers.
    #[inline(always)]
    fn mul_add(self, mul: Self, add: Self) -> Self {
        (self * mul) + add
    }

    /// `add - self * mul` (negated multiply-add), same non-fused caveat as
    /// `mul_add`.
    #[inline(always)]
    fn neg_mul_add(self, mul: Self, add: Self) -> Self {
        -(self * mul) + add
    }

    /// A one-lane splat is just the value itself.
    #[inline(always)]
    fn splat(_d: Self::Descriptor, v: f32) -> Self {
        v
    }

    #[inline(always)]
    fn zero(_d: Self::Descriptor) -> Self {
        0.0
    }

    #[inline(always)]
    fn abs(self) -> Self {
        self.abs()
    }

    #[inline(always)]
    fn floor(self) -> Self {
        self.floor()
    }

    #[inline(always)]
    fn sqrt(self) -> Self {
        self.sqrt()
    }

    #[inline(always)]
    fn neg(self) -> Self {
        -self
    }

    /// Returns `self` with the sign bit taken from `sign`.
    #[inline(always)]
    fn copysign(self, sign: Self) -> Self {
        self.copysign(sign)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        self.max(other)
    }

    #[inline(always)]
    fn min(self, other: Self) -> Self {
        self.min(other)
    }

    /// Lane-wise `>`; the scalar mask is a plain bool.
    #[inline(always)]
    fn gt(self, other: Self) -> bool {
        self > other
    }

    /// Float-to-int conversion. NOTE(review): Rust's `as` saturates on
    /// out-of-range values and maps NaN to 0, whereas e.g. x86 CVTPS2DQ
    /// yields i32::MIN for those inputs — confirm the SIMD backends agree
    /// for the value ranges callers actually pass.
    #[inline(always)]
    fn as_i32(self) -> Wrapping<i32> {
        Wrapping(self as i32)
    }

    /// Reinterprets the f32 bit pattern as an i32 (no numeric conversion).
    #[inline(always)]
    fn bitcast_to_i32(self) -> Wrapping<i32> {
        Wrapping(self.to_bits() as i32)
    }

    /// The scalar "prepared" bf16 table is just a copy of the 8 entries.
    #[inline(always)]
    fn prepare_table_bf16_8(_d: Self::Descriptor, table: &[f32; 8]) -> [f32; 8] {
        *table
    }

    /// Looks up one table entry. Panics unless the index is in 0..8; a
    /// negative index wraps to a huge usize via `as` and also panics.
    #[inline(always)]
    fn table_lookup_bf16_8(_d: Self::Descriptor, table: [f32; 8], indices: Wrapping<i32>) -> Self {
        table[indices.0 as usize]
    }

    /// Rounds to nearest and stores as u8. The `as u8` cast saturates to
    /// 0..=255 and maps NaN to 0 (Rust float-to-int cast semantics).
    #[inline(always)]
    fn round_store_u8(self, dest: &mut [u8]) {
        dest[0] = self.round() as u8;
    }

    /// Rounds to nearest and stores as u16 (saturating cast, NaN -> 0).
    #[inline(always)]
    fn round_store_u16(self, dest: &mut [u16]) {
        dest[0] = self.round() as u16;
    }

    /// Loads one IEEE half-precision value (given as raw bits) widened to f32.
    #[inline(always)]
    fn load_f16_bits(_d: Self::Descriptor, mem: &[u16]) -> Self {
        f16::from_bits(mem[0]).to_f32()
    }

    /// Narrows to half precision and stores the raw f16 bits.
    #[inline(always)]
    fn store_f16_bits(self, dest: &mut [u16]) {
        dest[0] = f16::from_f32(self).to_bits();
    }

    // Supplies the shared array-interface methods (and presumably the
    // `UnderlyingArray` associated type used below) — see the macro's
    // definition in the crate.
    impl_f32_array_interface!();

    /// No-op: with LEN == 1 the tile is 1x1 and transposition is the identity.
    #[inline(always)]
    fn transpose_square(_d: Self::Descriptor, _data: &mut [Self::UnderlyingArray], _stride: usize) {
    }
}
236
237impl I32SimdVec for Wrapping<i32> {
238 type Descriptor = ScalarDescriptor;
239
240 const LEN: usize = 1;
241
242 #[inline(always)]
243 fn splat(_d: Self::Descriptor, v: i32) -> Self {
244 Wrapping(v)
245 }
246
247 #[inline(always)]
248 fn load(_d: Self::Descriptor, mem: &[i32]) -> Self {
249 Wrapping(mem[0])
250 }
251
252 #[inline(always)]
253 fn store(&self, mem: &mut [i32]) {
254 mem[0] = self.0;
255 }
256
257 #[inline(always)]
258 fn abs(self) -> Self {
259 Wrapping(self.0.abs())
260 }
261
262 #[inline(always)]
263 fn as_f32(self) -> f32 {
264 self.0 as f32
265 }
266
267 #[inline(always)]
268 fn bitcast_to_f32(self) -> f32 {
269 f32::from_bits(self.0 as u32)
270 }
271
272 #[inline(always)]
273 fn bitcast_to_u32(self) -> Wrapping<u32> {
274 Wrapping(self.0 as u32)
275 }
276
277 #[inline(always)]
278 fn gt(self, other: Self) -> bool {
279 self.0 > other.0
280 }
281
282 #[inline(always)]
283 fn lt_zero(self) -> bool {
284 self.0 < 0
285 }
286
287 #[inline(always)]
288 fn eq(self, other: Self) -> bool {
289 self.0 == other.0
290 }
291
292 #[inline(always)]
293 fn eq_zero(self) -> bool {
294 self.0 == 0
295 }
296
297 #[inline(always)]
298 fn shl<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
299 Wrapping(self.0 << AMOUNT_U)
300 }
301
302 #[inline(always)]
303 fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
304 Wrapping(self.0 >> AMOUNT_U)
305 }
306
307 #[inline(always)]
308 fn mul_wide_take_high(self, rhs: Self) -> Self {
309 Wrapping(((self.0 as i64 * rhs.0 as i64) >> 32) as i32)
310 }
311
312 #[inline(always)]
313 fn store_u16(self, dest: &mut [u16]) {
314 dest[0] = self.0 as u16;
315 }
316
317 #[inline(always)]
318 fn store_u8(self, dest: &mut [u8]) {
319 dest[0] = self.0 as u8;
320 }
321}
322
323impl U32SimdVec for Wrapping<u32> {
324 type Descriptor = ScalarDescriptor;
325
326 const LEN: usize = 1;
327
328 #[inline(always)]
329 fn bitcast_to_i32(self) -> Wrapping<i32> {
330 Wrapping(self.0 as i32)
331 }
332
333 #[inline(always)]
334 fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
335 Wrapping(self.0 >> AMOUNT_U)
336 }
337}
338
339impl U8SimdVec for u8 {
340 type Descriptor = ScalarDescriptor;
341 const LEN: usize = 1;
342
343 #[inline(always)]
344 fn load(_d: Self::Descriptor, mem: &[u8]) -> Self {
345 mem[0]
346 }
347
348 #[inline(always)]
349 fn splat(_d: Self::Descriptor, v: u8) -> Self {
350 v
351 }
352
353 #[inline(always)]
354 fn store(&self, mem: &mut [u8]) {
355 mem[0] = *self;
356 }
357
358 #[inline(always)]
359 fn store_interleaved_2(a: Self, b: Self, dest: &mut [u8]) {
360 dest[0] = a;
361 dest[1] = b;
362 }
363
364 #[inline(always)]
365 fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u8]) {
366 dest[0] = a;
367 dest[1] = b;
368 dest[2] = c;
369 }
370
371 #[inline(always)]
372 fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u8]) {
373 dest[0] = a;
374 dest[1] = b;
375 dest[2] = c;
376 dest[3] = d;
377 }
378}
379
380impl U16SimdVec for u16 {
381 type Descriptor = ScalarDescriptor;
382 const LEN: usize = 1;
383
384 #[inline(always)]
385 fn load(_d: Self::Descriptor, mem: &[u16]) -> Self {
386 mem[0]
387 }
388
389 #[inline(always)]
390 fn splat(_d: Self::Descriptor, v: u16) -> Self {
391 v
392 }
393
394 #[inline(always)]
395 fn store(&self, mem: &mut [u16]) {
396 mem[0] = *self;
397 }
398
399 #[inline(always)]
400 fn store_interleaved_2(a: Self, b: Self, dest: &mut [u16]) {
401 dest[0] = a;
402 dest[1] = b;
403 }
404
405 #[inline(always)]
406 fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u16]) {
407 dest[0] = a;
408 dest[1] = b;
409 dest[2] = c;
410 }
411
412 #[inline(always)]
413 fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u16]) {
414 dest[0] = a;
415 dest[1] = b;
416 dest[2] = c;
417 dest[3] = d;
418 }
419}
420
421impl SimdMask for bool {
422 type Descriptor = ScalarDescriptor;
423
424 #[inline(always)]
425 fn if_then_else_f32(self, if_true: f32, if_false: f32) -> f32 {
426 if self { if_true } else { if_false }
427 }
428
429 #[inline(always)]
430 fn if_then_else_i32(self, if_true: Wrapping<i32>, if_false: Wrapping<i32>) -> Wrapping<i32> {
431 if self { if_true } else { if_false }
432 }
433
434 #[inline(always)]
435 fn maskz_i32(self, v: Wrapping<i32>) -> Wrapping<i32> {
436 if self { Wrapping(0) } else { v }
437 }
438
439 #[inline(always)]
440 fn all(self) -> bool {
441 self
442 }
443
444 #[inline(always)]
445 fn andnot(self, rhs: Self) -> Self {
446 (!self) & rhs
447 }
448}
449
/// Fallback `simd_function!` for architectures with no SIMD backend
/// (anything other than x86_64 / aarch64 / wasm32).
///
/// Expands the annotated function into two items:
/// - `$name`: the implementation, generic over a [`SimdDescriptor`];
/// - `$dname`: a concrete entry point that always dispatches to
///   [`ScalarDescriptor`], the only backend available on these targets.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! simd_function {
    (
        $dname:ident,
        $descr:ident: $descr_ty:ident,
        $(#[$($attr:meta)*])*
        $pub:vis fn $name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )? $body: block
    ) => {
        // Generic implementation, parameterized over the SIMD descriptor.
        #[inline(always)]
        $(#[$($attr)*])*
        $pub fn $name<$descr_ty: $crate::SimdDescriptor>($descr: $descr_ty, $($arg: $ty),*) $(-> $ret)? $body
        // Concrete dispatcher: scalar never fails to construct, so the
        // `unwrap()` cannot panic (ScalarDescriptor::new() is always Some).
        $(#[$($attr)*])*
        $pub fn $dname($($arg: $ty),*) $(-> $ret)? {
            use $crate::SimdDescriptor;
            $name($crate::ScalarDescriptor::new().unwrap(), $($arg),*)
        }
    };
}
473
/// Fallback `test_all_instruction_sets!` for architectures with no SIMD
/// backend: generates a single `#[test]` named `<name>_scalar` that runs the
/// given generic test function with the scalar descriptor (the only
/// "instruction set" available on these targets).
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! test_all_instruction_sets {
    (
        $name:ident
    ) => {
        // `paste` concatenates the suffix into the test function's name.
        paste::paste! {
            #[test]
            fn [<$name _scalar>]() {
                use $crate::SimdDescriptor;
                $name($crate::ScalarDescriptor::new().unwrap())
            }
        }
    };
}
493
/// Fallback `bench_all_instruction_sets!` for architectures with no SIMD
/// backend: invokes the given generic benchmark function once with the
/// scalar descriptor, the provided criterion handle, and the backend label
/// `"scalar"`. Expands to statements, so it must be used inside a function
/// body.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! bench_all_instruction_sets {
    (
        $name:ident,
        $criterion:ident
    ) => {
        use $crate::SimdDescriptor;
        $name(
            $crate::ScalarDescriptor::new().unwrap(),
            $criterion,
            "scalar",
        );
    };
}
512}