diskann_wide/arch/mod.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

//! Traits and functions supporting multi-architecture applications.
//!
//! Many SIMD instructions are micro-architecture specific, meaning that only a subset of
//! CPUs found in the wild can support SIMD accelerated algorithms. This module provides
//! tools for writing SIMD algorithms supporting multiple architectures and provides a
//! light-weight runtime dispatching service to select the most appropriate implementation
//! at run time.
//!
//! The example code below demonstrates a multi-versioned `X = X + Y` kernel:
//! ```rust
//! use diskann_wide::arch::{Target2, dispatch2};
//!
//! // A zero-sized type that we can use to implement a trait.
//! struct Add;
//!
//! impl<A: diskann_wide::Architecture> Target2<A, (), &mut [f32], &[f32]> for Add {
//!     #[inline]
//!     fn run(self, _: A, dst: &mut [f32], src: &[f32]) {
//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
//!     }
//! }
//!
//! fn add(dst: &mut [f32], src: &[f32]) {
//!     dispatch2(Add, dst, src)
//! }
//!
//! let mut dst = vec![1.0, 2.0, 3.0];
//! add(&mut dst, &[2.0, 3.0, 4.0]);
//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
//! ```
//!
//! Let's break down what's happening.
//!
//! The function [`dispatch2`] (suffixed with "2" because it takes two arguments, more on
//! this later) takes the struct `Add`, which is required to implement [`Target2`] for all
//! micro-architecture levels supported by `wide` for the compilation target CPU.
//!
//! It will then determine at run time the features supported by the current CPU and invoke
//! `Add::run` with the best architecture. The above example does not use any explicit SIMD
//! and is generic with respect to the [`Architecture`]. This still allows the compiler to
//! perform auto-vectorization for different platforms, which can still result in a speed-up.
//!
//! ## Mechanics of Dispatching
//!
//! Run-time architecture detection happens only once (modulo race conditions); once
//! resolved, each dispatch involves just an atomic load and a branch.
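//!
//! A minimal sketch of the general pattern (the names and the numeric encoding here are
//! hypothetical, not this module's actual implementation):
//! ```rust
//! use std::sync::atomic::{AtomicUsize, Ordering};
//!
//! // Hypothetical encoding: 0 = unresolved, 1 = accelerated, anything else = scalar.
//! static ARCH_NUMBER: AtomicUsize = AtomicUsize::new(0);
//!
//! fn detect() -> usize {
//!     // Stand-in for real CPU feature detection.
//!     1
//! }
//!
//! fn dispatch_example() {
//!     match ARCH_NUMBER.load(Ordering::Relaxed) {
//!         0 => {
//!             // First call (or a benign race): detect, publish, and retry.
//!             ARCH_NUMBER.store(detect(), Ordering::Relaxed);
//!             dispatch_example()
//!         }
//!         1 => { /* run the accelerated implementation */ }
//!         _ => { /* run the scalar fallback */ }
//!     }
//! }
//!
//! dispatch_example();
//! ```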
//!
//! ## Variadic traits and ABI
//!
//! Traits like [`Target1`] and functions like [`dispatch2`] are suffixed by the number of
//! additional arguments that they take. This is important for cases where we want the
//! calling convention of the dispatched-to function to match the calling convention of the
//! dispatcher function. In this case, Rust will invoke the dispatched-to function using a
//! jump instead of a function call.
//!
//! For example, if we had implemented the add method above without these variadics:
//! ```rust
//! use diskann_wide::arch::{Target, dispatch};
//!
//! struct AddV2<'a>(&'a mut [f32], &'a [f32]);
//!
//! impl<A: diskann_wide::Architecture> Target<A, ()> for AddV2<'_> {
//!     #[inline]
//!     fn run(self, _: A) {
//!         std::iter::zip(self.0.iter_mut(), self.1.iter()).for_each(|(d, s)| *d += *s);
//!     }
//! }
//!
//! #[inline(never)]
//! fn add_v2(dst: &mut [f32], src: &[f32]) {
//!     dispatch(AddV2(dst, src))
//! }
//!
//! let mut dst = vec![1.0, 2.0, 3.0];
//! add_v2(&mut dst, &[2.0, 3.0, 4.0]);
//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
//! ```
//! then the function still works, but the assembly code goes from something that looks like
//! this
//! ```asm
//! mov rax, qword ptr [rip + diskann_wide::arch::x86_64::ARCH_NUMBER@GOTPCREL]
//! mov rax, qword ptr [rax]
//! cmp rax, 1
//! je diskann_wide::arch::x86_64::V3::run_with_2
//! test rax, rax
//! je diskann_wide::arch::x86_64::dispatch_resolve2
//! <scalar-code>
//! ```
//! where there are no stack writes and the `V3` compatible code is reached via `jmp`, to
//! something that looks like
//! ```asm
//! sub rsp, 56
//! mov rax, qword ptr fs:[40]
//! mov qword ptr [rsp + 48], rax
//! mov rax, qword ptr [rip + diskann_wide::arch::x86_64::ARCH_NUMBER@GOTPCREL]
//! mov rax, qword ptr [rax]
//! test rax, rax
//! je .LBB14_3
//! cmp rax, 1
//! jne .LBB14_2
//! mov qword ptr [rsp + 8], rdi
//! mov qword ptr [rsp + 16], 3
//! lea rax, [rip + .L__unnamed_8]
//! mov qword ptr [rsp + 24], rax
//! mov qword ptr [rsp + 32], 3
//! lea rax, [rsp + 7]
//! mov qword ptr [rsp + 40], rax
//! lea rdi, [rsp + 8]
//! call diskann_wide::arch::x86_64::V3::run_with
//! jmp .LBB14_5
//!.LBB14_3:
//! mov qword ptr [rsp + 8], rdi
//! mov qword ptr [rsp + 16], 3
//! lea rax, [rip + .L__unnamed_8]
//! mov qword ptr [rsp + 24], rax
//! mov qword ptr [rsp + 32], 3
//! lea rdi, [rsp + 8]
//! call diskann_wide::arch::x86_64::dispatch_resolve::<loop_example::AddV2, ()>
//! jmp .LBB14_5
//! <scalar-code>
//! ```
//! Notice some unconditional stack writes, stack preparation for calling the `V3` compatible
//! code, and a `call` to run the code.
//!
//! What's happening is that Rust will not inline functions annotated with
//! `#[target_feature(enable = "feature")]` into an incompatible context. Since `AddV2`
//! exceeds 16 bytes, it must be passed on the stack (on Linux at least). Therefore, we have
//! the extra overhead of stack preparation to call the dispatch target.
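//!
//! For illustration, a stand-alone sketch of the inlining constraint (the function
//! `kernel` is hypothetical and not part of this crate):
//! ```rust
//! #[cfg(target_arch = "x86_64")]
//! #[target_feature(enable = "avx2")]
//! unsafe fn kernel(dst: &mut [f32], src: &[f32]) {
//!     // Even with an inline hint, this body cannot be inlined into a caller that
//!     // is not itself compiled with AVX2 enabled.
//!     std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
//! }
//! ```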
//!
//! ## Function Pointer API
//!
//! The previous section discussed performing a dynamic dispatch at a single call site, but
//! you need to pay the (admittedly small) overhead every time this function is called.
//! Another approach would be to perform dispatch a single time by obtaining a function pointer
//! to the dispatched function and then calling through that function pointer.
//!
//! This is a little tricky for two reasons (and a whole host of less obvious reasons).
//!
//! Reason 1: Functions with additional `target_features` cannot be inlined into functions
//! compiled without those features. This means the simple approach of
//! ```rust
//! use diskann_wide::{Architecture, arch::{Scalar, Target, Target2, dispatch2}};
//!
//! // A zero-sized type that we can use to implement a trait.
//! struct Add;
//!
//! impl<A: Architecture> Target2<A, (), &mut [f32], &[f32]> for Add {
//!     #[inline]
//!     fn run(self, _: A, dst: &mut [f32], src: &[f32]) {
//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
//!     }
//! }
//!
//! impl<A: Architecture> Target<A, fn(A, &mut [f32], &[f32])> for Add {
//!     fn run(self, arch: A) -> fn(A, &mut [f32], &[f32]) {
//!         // Create a non-capturing closure that invokes `arch.run2`.
//!         //
//!         // The invocation of `arch.run2` will apply the necessary target features and
//!         // the non-capturing closure can be coerced into a function pointer.
//!         let f = |arch: A, dst: &mut [f32], src: &[f32]| arch.run2(Add, dst, src);
//!         f
//!     }
//! }
//!
//! let f: fn(Scalar, &mut [f32], &[f32]) = (Scalar).run(Add);
//!
//! let mut dst = vec![1.0, 2.0, 3.0];
//! f(Scalar, &mut dst, &[2.0, 3.0, 4.0]);
//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
//! ```
//! would likely generate code that looks something like:
//! ```asm
//! .section .text.<<diskann_wide::Add as diskann_wide::arch::Target<_, _>::run::{closure#0} /* snip */>>
//! .p2align 4
//!.type <<diskann_wide::Add as _>::call_once,@function
//!<<diskann_wide::Add as _>>::call_once:
//! .cfi_startproc
//! jmp <diskann_wide::arch::Scalar>::run_with_2::<diskann_wide::Add, &mut [f32], &[f32], ()>
//! ```
//! The body is simply an unconditional jump to the actual implementation precisely because
//! the actual implementation cannot be inlined into the body of the closure we coerced into
//! a function pointer. Unfortunately, the same applies to most other ways one would try to
//! create a function pointer to the dispatched-to function.
//!
//! The consequence of this is that we need to take an **unsafe** function pointer so we
//! can dispatch calls directly to the implementation.
//!
//! Reason 2: Even if the above approach worked, the [`Architecture`] is still present in the
//! signature of the `fn`, meaning we haven't really hidden the micro-architecture
//! information.
//!
//! With that in mind, the current solution looks like the following.
//! ```rust
//! use diskann_wide::{
//!     Architecture,
//!     arch::{self, Dispatched2},
//!     lifetime::{Ref, Mut},
//! };
//!
//! struct Add;
//!
//! // Note the use of `FTarget` instead of `Target`. That is because the implementation is
//! // simply an associated function instead of a method.
//! impl<A: Architecture> arch::FTarget2<A, (), &mut [f32], &[f32]> for Add {
//!     #[inline]
//!     fn run(_: A, dst: &mut [f32], src: &[f32]) {
//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
//!     }
//! }
//!
//! // The `Dispatched2` struct is a slightly magical wrapper around a function pointer,
//! // returning the unit type `()` and taking two arguments; one `&mut [f32]` and
//! // the other `&[f32]`.
//! //
//! // The need for `Mut` and `Ref` is described below.
//! type FnPtr = Dispatched2<(), Mut<[f32]>, Ref<[f32]>>;
//!
//! impl<A: Architecture> arch::Target<A, FnPtr> for Add {
//!     fn run(self, arch: A) -> FnPtr {
//!         arch.dispatch2::<Self, (), Mut<[f32]>, Ref<[f32]>>()
//!     }
//! }
//!
//! let f: FnPtr = diskann_wide::arch::dispatch(Add);
//! let mut dst = vec![1.0, 2.0, 3.0];
//!
//! // Invoke the function pointer
//! f.call(&mut dst, &[2.0, 3.0, 4.0]);
//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
//! ```
//! The resulting function pointer (though "safely unsafe") successfully hides the
//! dispatched micro-architecture and, when called, will always go directly to the
//! implementation rather than needing the trampoline.
//!
//! The [`Architecture`] methods
//! * [`Architecture::dispatch1`], [`Architecture::dispatch2`], and [`Architecture::dispatch3`]
//!
//! will produce function pointers of the annotated arities
//!
//! * [`Dispatched1`], [`Dispatched2`], and [`Dispatched3`]
//!
//! and are accompanied by the generator traits
//!
//! * [`FTarget1`], [`FTarget2`], and [`FTarget3`].
//!
//! ### Lifetime Annotations and Limitations
//!
//! One thing to note in the above example is the use of [`crate::lifetime::Mut`] and
//! [`crate::lifetime::Ref`] in the invocation of [`Architecture::dispatch2`]. This is an
//! unfortunate limitation of the Rust compiler, at the time of writing, when it comes to
//! inferring lifetimes of function pointers.
//!
//! For example, the following does not compile:
//! ```compile_fail
//! pub struct Example;
//!
//! trait Run<T, U> {
//!     fn run(x: T, y: U);
//! }
//!
//! impl Example {
//!     fn run<F, T, U>(self, x: T, y: U)
//!     where
//!         F: Run<T, U>,
//!     {
//!         F::run(x, y)
//!     }
//! }
//!
//! struct Add;
//!
//! impl Run<&mut [f32], &[f32]> for Add {
//!     #[inline]
//!     fn run(dst: &mut [f32], src: &[f32]) {
//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
//!     }
//! }
//!
//! // This fails to compile! :(
//! pub fn make() -> fn(Example, &mut [f32], &[f32]) {
//!     let f = Example::run::<Add, &mut [f32], &[f32]>;
//!     f
//! }
//! ```
//! It fails with the following error message:
//! ```text
//! error[E0308]: mismatched types
//!   --> <source>:27:5
//!    |
//! 25 | pub fn make() -> fn(Example, &mut [f32], &[f32]) {
//!    |                  ------------------------------- expected `for<'a, 'b> fn(Example, &'a mut [f32], &'b [f32])` because of return type
//! 26 |     let f = Example::run::<Add, &mut [f32], &[f32]>;
//! 27 |     f
//!    |     ^ one type is more general than the other
//!    |
//!    = note: expected fn pointer `for<'a, 'b> fn(Example, &'a mut _, &'b _)`
//!              found fn item `fn(Example, &mut _, &_) {Example::run::<Add, &mut [f32], &[f32]>}`
//! ```
//! The [`crate::lifetime::AddLifetime`] trait is the only solution the author found, at the
//! time of writing, that avoids both the trampoline induced by the closure-like approaches
//! and the above lifetime error. This leads to a few practical limitations:
//!
//! * Types passed through the function pointer interface can have at most a single lifetime.
//! * The lifetimes of types passed through the function pointer interface must all be disjoint.
//! * The return type cannot have a lifetime.
//!
//! Practically, these limitations are acceptable because micro-architecture dispatching is
//! almost always applied to mathematical manipulation of primitive types, and thus the
//! lifetimes of the associated types are generally not complicated.
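//!
//! To make the `Mut`/`Ref` markers less mysterious, here is a conceptual sketch of the
//! idea behind [`crate::lifetime::AddLifetime`] (this is illustrative only, not the
//! crate's actual definition):
//! ```rust
//! // A generic associated type maps a lifetime-erased marker type to its
//! // lifetime-carrying form, enabling `for<'a> fn(T::Of<'a>)` bounds.
//! trait AddLifetime {
//!     type Of<'a>;
//! }
//!
//! // Marker standing in for `&'a T` with the lifetime stripped.
//! struct Ref<T: ?Sized + 'static>(std::marker::PhantomData<T>);
//!
//! impl<T: ?Sized + 'static> AddLifetime for Ref<T> {
//!     type Of<'a> = &'a T;
//! }
//! ```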
//!
//! ## Obtaining the [`Current`] Architecture
//!
//! When Rust crates are compiled, they can be provided with a target CPU. The function
//! [`current`], the type [`Current`], and the constant [`crate::ARCH`] are all populated with
//! the best matching wide [`Architecture`] selected at compile time.
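//!
//! A minimal sketch (the binding names are illustrative only):
//! ```rust
//! use diskann_wide::arch;
//!
//! // Both views agree on the compile-time selection.
//! let from_const = diskann_wide::ARCH;
//! let from_fn: arch::Current = arch::current();
//! assert_eq!(from_const, from_fn);
//! ```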
//!
//! ## Hierarchies
//!
//! Each [`Architecture`] exposes a [`Level`] via [`Architecture::level()`] that
//! can be used to compare capabilities without instantiating the architecture.
//!
//! ### X86
//!
//! * [`x86_64::V4`]: Supporting AVX-512 (and AVX2 and lower).
//! * [`x86_64::V3`]: Supporting AVX2 and lower.
//! * [`Scalar`]: Fallback architecture.
//!
//! The ordering is `Scalar` < `V3` < `V4`.
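//!
//! This ordering can be checked through [`Architecture::level`]; a minimal sketch:
//! ```rust
//! #[cfg(target_arch = "x86_64")]
//! use diskann_wide::{Architecture, arch};
//!
//! #[cfg(target_arch = "x86_64")]
//! assert!(arch::Scalar::level() < arch::x86_64::V3::level());
//! #[cfg(target_arch = "x86_64")]
//! assert!(arch::x86_64::V3::level() < arch::x86_64::V4::level());
//! ```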
//!
//! ### Arm
//!
//! * [`aarch64::Neon`]: Supporting NEON.
//! * [`Scalar`]: Fallback architecture.
//!
//! The ordering is `Scalar` < `Neon`.

use half::f16;

use crate::{
    Const, SIMDCast, SIMDDotProduct, SIMDFloat, SIMDMask, SIMDSelect, SIMDSigned, SIMDSumTree,
    SIMDUnsigned, SIMDVector, SplitJoin, ZipUnzip, lifetime::AddLifetime,
};

pub(crate) mod emulated;

/// An [`Architecture`] that implements all operations as scalar loops, relying on the
/// compiler for optimization.
pub use emulated::Scalar;

/// An opaque representation of an [`Architecture`]'s capability level.
///
/// `Level` allows comparing the relative capabilities of different architectures
/// without requiring an instance of the architecture type. This is useful for
/// compile-time checks against [`crate::ARCH`] where constructing architecture
/// types like [`x86_64::V3`] would require `unsafe`.
///
/// Levels are totally ordered within an ISA family, with greater values indicating
/// more capable instruction sets. [`Scalar`] is always the lowest level.
///
/// # Examples
///
/// Checking if the compile-time architecture meets a minimum capability:
///
/// ```
/// #[cfg(target_arch = "x86_64")]
/// use diskann_wide::{Architecture, arch};
///
/// // Check at compile time whether we were built with AVX2+ support.
/// #[cfg(target_arch = "x86_64")]
/// let _meets_v3 = arch::Current::level() >= arch::x86_64::V3::level();
/// ```
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Level(LevelInner);

impl Level {
    const fn scalar() -> Self {
        Self(LevelInner::Scalar)
    }
}

cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        // Delegate to the architecture selection within the `x86_64` module.
        pub mod x86_64;

        use x86_64::LevelInner;

        pub use x86_64::current;
        pub use x86_64::Current;

        pub use x86_64::dispatch;
        pub use x86_64::dispatch1;
        pub use x86_64::dispatch2;
        pub use x86_64::dispatch3;

        pub use x86_64::dispatch_no_features;
        pub use x86_64::dispatch1_no_features;
        pub use x86_64::dispatch2_no_features;
        pub use x86_64::dispatch3_no_features;

        impl Level {
            const fn v3() -> Self {
                Self(LevelInner::V3)
            }

            const fn v4() -> Self {
                Self(LevelInner::V4)
            }
        }
    } else if #[cfg(target_arch = "aarch64")] {
        // Delegate to the architecture selection within the `aarch64` module.
        pub mod aarch64;

        use aarch64::LevelInner;

        pub use aarch64::current;
        pub use aarch64::Current;

        pub use aarch64::dispatch;
        pub use aarch64::dispatch1;
        pub use aarch64::dispatch2;
        pub use aarch64::dispatch3;

        pub use aarch64::dispatch_no_features;
        pub use aarch64::dispatch1_no_features;
        pub use aarch64::dispatch2_no_features;
        pub use aarch64::dispatch3_no_features;

        impl Level {
            const fn neon() -> Self {
                Self(LevelInner::Neon)
            }
        }
    } else {
        pub type Current = Scalar;

        // There is only one architecture present in this mode.
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
        enum LevelInner {
            Scalar,
        }

        pub const fn current() -> Current {
            Scalar::new()
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch<T, R>(f: T) -> R
        where T: Target<Scalar, R> {
            f.run(Scalar::new())
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch1<T, T0, R>(f: T, x0: T0) -> R
        where T: Target1<Scalar, R, T0> {
            f.run(Scalar::new(), x0)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch2<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
        where T: Target2<Scalar, R, T0, T1> {
            f.run(Scalar::new(), x0, x1)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch3<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
        where T: Target3<Scalar, R, T0, T1, T2> {
            f.run(Scalar::new(), x0, x1, x2)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch_no_features<T, R>(f: T) -> R
        where T: Target<Scalar, R> {
            f.run(Scalar::new())
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch1_no_features<T, T0, R>(f: T, x0: T0) -> R
        where T: Target1<Scalar, R, T0> {
            f.run(Scalar::new(), x0)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch2_no_features<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
        where T: Target2<Scalar, R, T0, T1> {
            f.run(Scalar::new(), x0, x1)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch3_no_features<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
        where T: Target3<Scalar, R, T0, T1, T2> {
            f.run(Scalar::new(), x0, x1, x2)
        }
    }
}

mod sealed {
    pub trait Sealed: std::fmt::Debug + Copy + PartialEq + Send + Sync + 'static {}
}

pub(crate) use sealed::Sealed;

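// Declare an associated SIMD vector type on the `Architecture` trait: `$me` names the
// type, with scalar type `$T`, lane count `$N`, the given mask type, and any additional
// trait bounds in `$rest`.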
macro_rules! vector {
    ($me:ident: <$self:ident, $T:ty, $N:literal, $mask:ident> + $($rest:tt)*) => {
        type $me: SIMDVector<Arch = $self, Scalar = $T, ConstLanes = Const<$N>, Mask = Self::$mask> + $($rest)*;
    }
}

#[allow(non_camel_case_types)]
pub trait Architecture: sealed::Sealed {
    // mask types
    type mask_f16x8: SIMDMask;
    type mask_f16x16: SIMDMask;

    type mask_f32x4: SIMDMask + SIMDSelect<Self::f32x4>;
    type mask_f32x8: SIMDMask + SIMDSelect<Self::f32x8>;
    type mask_f32x16: SIMDMask + SIMDSelect<Self::f32x16>;

    type mask_i8x16: SIMDMask;
    type mask_i8x32: SIMDMask;
    type mask_i8x64: SIMDMask;

    type mask_i16x8: SIMDMask;
    type mask_i16x16: SIMDMask;
    type mask_i16x32: SIMDMask;

    type mask_i32x4: SIMDMask;
    type mask_i32x8: SIMDMask + From<Self::mask_f32x8> + SIMDSelect<Self::i32x8>;
    type mask_i32x16: SIMDMask + SIMDSelect<Self::i32x16>;

    type mask_u8x16: SIMDMask;
    type mask_u8x32: SIMDMask;
    type mask_u8x64: SIMDMask;

    type mask_u32x4: SIMDMask;
    type mask_u32x8: SIMDMask + From<Self::mask_f32x8>;
    type mask_u32x16: SIMDMask + SIMDSelect<Self::u32x16>;
    type mask_u64x2: SIMDMask;
    type mask_u64x4: SIMDMask;

    /////////////////
    //-- vectors --//
    /////////////////

    // floats
    vector!(
        f16x8: <Self, f16, 8, mask_f16x8>
            + SIMDCast<f32, Cast = Self::f32x8>
    );
    vector!(
        f16x16: <Self, f16, 16, mask_f16x16>
            + SplitJoin<Halved = Self::f16x8>
            + ZipUnzip<Halved = Self::f16x8>
            + SIMDCast<f32, Cast = Self::f32x16>
    );

    vector!(
        f32x4: <Self, f32, 4, mask_f32x4>
            + SIMDFloat
            + SIMDSumTree
    );
    vector!(
        f32x8: <Self, f32, 8, mask_f32x8>
            + SIMDFloat
            + SIMDSumTree
            + SIMDCast<f16, Cast = Self::f16x8>
            + SplitJoin<Halved = Self::f32x4>
            + From<Self::f16x8>
    );
    vector!(
        f32x16: <Self, f32, 16, mask_f32x16>
            + SIMDFloat
            + SplitJoin<Halved = Self::f32x8>
            + SIMDSumTree
            + From<Self::f16x16>
    );

    // signed-integer
    vector!(
        i8x16: <Self, i8, 16, mask_i8x16>
            + SIMDSigned
    );
    vector!(
        i8x32: <Self, i8, 32, mask_i8x32>
            + SIMDSigned
            + SplitJoin<Halved = Self::i8x16>
            + ZipUnzip<Halved = Self::i8x16>
    );
    vector!(
        i8x64: <Self, i8, 64, mask_i8x64>
            + SIMDSigned
    );

    vector!(
        i16x8: <Self, i16, 8, mask_i16x8>
            + SIMDSigned
    );
    vector!(
        i16x16: <Self, i16, 16, mask_i16x16>
            + SIMDSigned
            + SplitJoin<Halved = Self::i16x8>
            + ZipUnzip<Halved = Self::i16x8>
            + From<Self::i8x16>
            + From<Self::u8x16>
    );
    vector!(
        i16x32: <Self, i16, 32, mask_i16x32>
            + SIMDSigned
            + SplitJoin<Halved = Self::i16x16>
            + From<Self::i8x32>
            + From<Self::u8x32>
    );

    vector!(
        i32x4: <Self, i32, 4, mask_i32x4>
            + SIMDSigned
    );
    vector!(
        i32x8: <Self, i32, 8, mask_i32x8>
            + SIMDSigned
            + SIMDSumTree
            + SplitJoin<Halved = Self::i32x4>
            + ZipUnzip<Halved = Self::i32x4>
            + SIMDDotProduct<Self::i16x16>
            + SIMDDotProduct<Self::u8x32, Self::i8x32>
            + SIMDDotProduct<Self::i8x32, Self::u8x32>
            + SIMDCast<f32, Cast = Self::f32x8>
    );
    vector!(
        i32x16: <Self, i32, 16, mask_i32x16>
            + SIMDSigned
            + SIMDSumTree
            + SplitJoin<Halved = Self::i32x8>
            + SIMDDotProduct<Self::u8x64, Self::i8x64>
            + SIMDDotProduct<Self::i8x64, Self::u8x64>
    );

    // unsigned-integer
    vector!(
        u8x16: <Self, u8, 16, mask_u8x16>
            + SIMDUnsigned
    );
    vector!(
        u8x32: <Self, u8, 32, mask_u8x32>
            + SIMDUnsigned
            + SplitJoin<Halved = Self::u8x16>
            + ZipUnzip<Halved = Self::u8x16>
    );

    vector!(
        u8x64: <Self, u8, 64, mask_u8x64>
            + SIMDUnsigned
    );

    vector!(
        u32x4: <Self, u32, 4, mask_u32x4>
            + SIMDUnsigned
    );
    vector!(
        u32x8: <Self, u32, 8, mask_u32x8>
            + SplitJoin<Halved = Self::u32x4>
            + ZipUnzip<Halved = Self::u32x4>
            + SIMDUnsigned
            + SIMDSumTree
    );
    vector!(
        u32x16: <Self, u32, 16, mask_u32x16>
            + SIMDUnsigned
            + SIMDSumTree
            + SplitJoin<Halved = Self::u32x8>
    );

    vector!(
        u64x2: <Self, u64, 2, mask_u64x2>
            + SIMDUnsigned
    );
    vector!(
        u64x4: <Self, u64, 4, mask_u64x4>
            + SplitJoin<Halved = Self::u64x2>
            + SIMDUnsigned
    );

    //---------//
    // Methods //
    //---------//

    /// Return an opaque [`Level`] representing the capabilities of this architecture.
    ///
    /// Levels that compare greater represent architectures that are more capable.
    ///
    /// # Examples
    ///
    /// ```
    /// use diskann_wide::{Architecture, arch};
    ///
    /// // Scalar is the baseline — every other architecture compares greater.
    /// assert_eq!(arch::Scalar::level(), arch::Scalar::level());
    ///
    /// #[cfg(target_arch = "x86_64")]
    /// assert!(arch::Scalar::level() < arch::x86_64::V3::level());
    /// ```
    fn level() -> Level;

    /// Run the provided closure targeting this architecture.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run<F, R>(self, f: F) -> R
    where
        F: Target<Self, R>;

    /// Run the provided closure targeting this architecture with an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run_inline<F, R>(self, f: F) -> R
    where
        F: Target<Self, R>;

    /// Run the provided closure targeting this architecture with an additional argument.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run1<F, T0, R>(self, f: F, x0: T0) -> R
    where
        F: Target1<Self, R, T0>;

    /// Run the provided closure targeting this architecture with an additional argument and
    /// an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target1` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run1_inline<F, T0, R>(self, f: F, x0: T0) -> R
    where
        F: Target1<Self, R, T0>;

    /// Run the provided closure targeting this architecture with two additional arguments.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run2<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
    where
        F: Target2<Self, R, T0, T1>;

    /// Run the provided closure targeting this architecture with two additional arguments
    /// and an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target2` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run2_inline<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
    where
        F: Target2<Self, R, T0, T1>;

    /// Run the provided closure targeting this architecture with three additional arguments.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run3<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
    where
        F: Target3<Self, R, T0, T1, T2>;

    /// Run the provided closure targeting this architecture with three additional arguments
    /// and an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target3` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run3_inline<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
    where
        F: Target3<Self, R, T0, T1, T2>;

    /// Return a function pointer invoking [`FTarget1::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{Architecture, arch::FTarget1};
    ///
    /// struct Square;
    ///
    /// impl<A: Architecture> FTarget1<A, f32, f32> for Square
    /// {
    ///     fn run(_: A, x: f32) -> f32 {
    ///         x * x
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch1::<Square, f32, f32>();
    /// assert_eq!(f.call(10.0), 100.0);
    /// ```
    fn dispatch1<F, R, T0>(self) -> Dispatched1<R, T0>
    where
        T0: AddLifetime,
        F: for<'a> FTarget1<Self, R, T0::Of<'a>>;

    /// Return a function pointer invoking [`FTarget2::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{
    ///     Architecture,
    ///     arch::FTarget2,
    ///     lifetime::{Mut, Ref},
    /// };
    ///
    /// struct Copy;
    ///
    /// // Copy a slice and return the number of elements.
    /// impl<A: Architecture> FTarget2<A, usize, &mut [f32], &[f32]> for Copy
    /// {
    ///     fn run(_: A, dst: &mut [f32], src: &[f32]) -> usize {
    ///         dst.copy_from_slice(src);
    ///         src.len()
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch2::<Copy, usize, Mut<[f32]>, Ref<[f32]>>();
    /// let src = [1.0, 2.0, 3.0];
    /// let mut dst = [0.0f32; 3];
    /// assert_eq!(f.call(&mut dst, &src), 3);
    /// assert_eq!(src, dst);
    /// ```
    fn dispatch2<F, R, T0, T1>(self) -> Dispatched2<R, T0, T1>
    where
        T0: AddLifetime,
        T1: AddLifetime,
        F: for<'a, 'b> FTarget2<Self, R, T0::Of<'a>, T1::Of<'b>>;

    /// Return a function pointer invoking [`FTarget3::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{
    ///     Architecture,
    ///     arch::FTarget3,
    ///     lifetime::Ref,
    /// };
    ///
    /// struct Sum;
    ///
    /// // Return the sum of the three arguments.
    /// impl<A: Architecture> FTarget3<A, f32, &f32, &f32, &f32> for Sum
    /// {
    ///     fn run(_: A, x: &f32, y: &f32, z: &f32) -> f32 {
    ///         x + y + z
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch3::<Sum, f32, Ref<f32>, Ref<f32>, Ref<f32>>();
    /// assert_eq!(f.call(&1.0, &2.0, &3.0), 6.0);
    /// ```
    fn dispatch3<F, R, T0, T1, T2>(self) -> Dispatched3<R, T0, T1, T2>
    where
        T0: AddLifetime,
        T1: AddLifetime,
        T2: AddLifetime,
        F: for<'a, 'b, 'c> FTarget3<Self, R, T0::Of<'a>, T1::Of<'b>, T2::Of<'c>>;
}

/// A functor that targets a particular architecture, accepting no additional arguments.
pub trait Target<A, R>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture`.
    fn run(self, arch: A) -> R;
}

/// A functor that targets a particular architecture, accepting one additional argument.
pub trait Target1<A, R, T0>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture`.
    fn run(self, arch: A, x0: T0) -> R;
}

/// A functor that targets a particular architecture, accepting two additional arguments.
pub trait Target2<A, R, T0, T1>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture`.
    fn run(self, arch: A, x0: T0, x1: T1) -> R;
}

/// A functor that targets a particular architecture, accepting three additional arguments.
pub trait Target3<A, R, T0, T1, T2>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture`.
    fn run(self, arch: A, x0: T0, x1: T1, x2: T2) -> R;
}

/// A variation of [`Target1`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
pub trait FTarget1<A, R, T0>
where
    A: Architecture,
{
    fn run(arch: A, x0: T0) -> R;
}

/// A variation of [`Target2`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
pub trait FTarget2<A, R, T0, T1>
where
    A: Architecture,
{
    fn run(arch: A, x0: T0, x1: T1) -> R;
}

/// A variation of [`Target3`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
pub trait FTarget3<A, R, T0, T1, T2>
where
    A: Architecture,
{
    fn run(arch: A, x0: T0, x1: T1, x2: T2) -> R;
}

/// Run the closure with code generated for the specified architecture.
///
/// Note that if the body of the closure is not inlined, this will likely have no effect.
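///
/// Because of this blanket implementation, a plain closure can be passed straight to the
/// dispatch entry points. A minimal sketch:
/// ```
/// use diskann_wide::arch::dispatch;
///
/// // The closure body is compiled for (and, if inlined, may be auto-vectorized for)
/// // the dispatched architecture.
/// let x: f32 = dispatch(|| [1.0f32, 2.0, 3.0].iter().sum());
/// assert_eq!(x, 6.0);
/// ```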
impl<A, R, F> Target<A, R> for F
where
    A: Architecture,
    F: FnOnce() -> R,
{
    #[inline]
    fn run(self, _: A) -> R {
        (self)()
    }
}

/// Run the closure with code generated for the specified architecture.
///
/// Note that if the body of the closure is not inlined, this will likely have no effect.
impl<A, R, T0, F> Target1<A, R, T0> for F
where
    A: Architecture,
    F: FnOnce(T0) -> R,
{
    #[inline]
    fn run(self, _: A, x0: T0) -> R {
        (self)(x0)
    }
}

/// Run the closure with code generated for the specified architecture.
///
/// Note that if the body of the closure is not inlined, this will likely have no effect.
impl<A, R, T0, T1, F> Target2<A, R, T0, T1> for F
where
    A: Architecture,
    F: FnOnce(T0, T1) -> R,
{
    #[inline]
    fn run(self, _: A, x0: T0, x1: T1) -> R {
        (self)(x0, x1)
    }
}

/// Run the closure with code generated for the specified architecture.
///
/// Note that if the body of the closure is not inlined, this will likely have no effect.
impl<A, R, T0, T1, T2, F> Target3<A, R, T0, T1, T2> for F
where
    A: Architecture,
    F: FnOnce(T0, T1, T2) -> R,
{
    #[inline]
    fn run(self, _: A, x0: T0, x1: T1, x2: T2) -> R {
        (self)(x0, x1, x2)
    }
}

/// A hidden architecture for use in the function pointer API.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct Hidden;

const _ASSERT_ZST: () = assert!(
    std::mem::size_of::<Hidden>() == 0,
    "Hidden **must** be zero sized"
);

const _ASSERT_ALIGNED: () = assert!(
    std::mem::align_of::<Hidden>() == 1,
    "Hidden **must** have alignment 1"
);

macro_rules! dispatched {
    ($name:ident, { $($Ts:ident )* }, { $($xs:ident )* }, { $($lt:lifetime )* }) => {
        /// A function pointer that calls directly into a micro-architecture optimized
        /// function, returning a value of type `R` and accepting the specified number of
        /// arguments.
        ///
        /// Arguments are mapped using the [`AddLifetime`] trait to enable passing structs
        /// with up to a single non-`'static` lifetime parameter.
        ///
        /// This type is guaranteed:
        ///
        /// * To have the same size as a regular function pointer.
        /// * To have the same ABI as a regular function pointer.
        /// * To use the null-pointer optimization - so `Option<Self>` has the same size as
        ///   `Self`.
        #[derive(Debug)]
        #[repr(transparent)]
        pub struct $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            f: for<$($lt,)*> unsafe fn(Hidden, $($Ts::Of<$lt>,)*) -> R,
        }

        impl<R, $($Ts,)*> $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            /// Construct a new safe instance of `Self` around the raw function pointer.
            ///
            /// # Safety
            ///
            /// The caller asserts that the runtime implementation of `f` is safe to call
            /// on the current CPU target.
            ///
            /// Usually, this means that the runtime CPU has the target features required
            /// for the destination of `f`.
            unsafe fn new(f: unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R) -> Self {
                Self { f }
            }

            /// Invoke the function with the lifetime-annotated arguments and return the
            /// result.
            ///
            /// The below example demonstrates the behavior for [`Dispatched1`], but the
            /// same logic applies to all variadic instances.
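            ///
            /// A minimal sketch (the functor `Double` is hypothetical, defined only for
            /// this example):
            /// ```
            /// use diskann_wide::{Architecture, arch::FTarget1};
            ///
            /// struct Double;
            ///
            /// impl<A: Architecture> FTarget1<A, f32, f32> for Double {
            ///     fn run(_: A, x: f32) -> f32 {
            ///         2.0 * x
            ///     }
            /// }
            ///
            /// // Resolve once, then call through the stored function pointer.
            /// let f = (diskann_wide::ARCH).dispatch1::<Double, f32, f32>();
            /// assert_eq!(f.call(4.0), 8.0);
            /// ```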
            #[inline(always)]
            pub fn call(self, $($xs: $Ts::Of<'_>,)*) -> R {
                // SAFETY: The constructor of `Dispatched` asserts the call is safe.
                unsafe { (self.f)(Hidden, $($xs,)*) }
            }
        }

        impl<R, $($Ts,)*> Clone for $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            fn clone(&self) -> Self {
                *self
            }
        }

        impl<R, $($Ts,)*> Copy for $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
        }
    }
}

dispatched!(Dispatched1, { T0 }, { x0 }, { 'a0 });
dispatched!(Dispatched2, { T0 T1 }, { x0 x1 }, { 'a0 'a1 });
dispatched!(Dispatched3, { T0 T1 T2 }, { x0 x1 x2 }, { 'a0 'a1 'a2 });

/// This macro stamps out the function-pointer transmute trick we use to type-erase
/// the architecture in the function-pointer API.
macro_rules! hide {
    ($name:ident, $dispatched:ident, { $($Ts:ident )* }) => {
        /// Construct a new dispatched function-pointer wrapper from the raw function
        /// pointer.
        ///
        /// # Safety
        ///
        /// Internally, we will transmute the zero-sized type `A` to another hidden type
        /// to erase the architecture information of the dispatched function pointer.
        ///
        /// This function will not compile if `A` is not zero sized or does not have an
        /// alignment of 1.
        ///
        /// We can do this because Rust guarantees that zero sized types are ABI
        /// compatible.
        ///
        /// The caller must ensure that winking into existence an instance of `A` is
        /// a safe operation. For [`Architecture`] implementations, this means that the
        /// requirements of `A::new()` are upheld.
        ///
        /// Put plainly:
        ///
        /// 1. The runtime CPU must support the target features required by `A`.
        /// 2. `A` **must** be a zero-sized type with an alignment of 1.
        unsafe fn $name<A, R, $($Ts,)*>(
            f: unsafe fn(A, $($Ts::Of<'_>,)*) -> R
        ) -> $dispatched<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            // Check that `A` is a zero sized type.
            const {
                assert!(
                    std::mem::size_of::<A>() == 0,
                    "A must be zero sized to be ABI compatible with `Hidden`"
                )
            };

            // Check that `A` has alignment 1.
            const {
                assert!(
                    std::mem::align_of::<A>() == 1,
                    "A must have an alignment of 1 to be ABI compatible with `Hidden`"
                )
            };

            // SAFETY: The transmute is safe because `Hidden` and `A` are both zero
            // sized types with alignment 1, which are ABI compatible.
            //
            // All the rest of the arguments are untouched.
            //
            // The caller asserts that it is safe to call this function pointer.
            let f = unsafe {
                std::mem::transmute::<
                    unsafe fn(A, $($Ts::Of<'_>,)*) -> R,
                    unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R
                >(f)
            };

            // SAFETY: The caller asserts that it is safe to call `f`.
            unsafe { $dispatched::new(f) }
        }
    }
}

hide!(hide1, Dispatched1, { T0 });
hide!(hide2, Dispatched2, { T0 T1 });
hide!(hide3, Dispatched3, { T0 T1 T2 });

// Macros to help implement architectures.

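// Define the `mask_*` associated types by projecting the `Mask` associated type out of
// the corresponding vector representation; the zero-argument form stamps out the full
// set used by `Architecture` implementations.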
macro_rules! maskdef {
    ($mask:ident = $repr:ty) => {
        type $mask = <$repr as SIMDVector>::Mask;
    };
    ($($mask:ident = $repr:ty),+ $(,)?) => {
        $($crate::arch::maskdef!($mask = $repr);)+
    };
    () => {
        $crate::arch::maskdef!(
            mask_f16x8 = f16x8,
            mask_f16x16 = f16x16,

            mask_f32x4 = f32x4,
            mask_f32x8 = f32x8,
            mask_f32x16 = f32x16,

            mask_i8x16 = i8x16,
            mask_i8x32 = i8x32,
            mask_i8x64 = i8x64,

            mask_i16x8 = i16x8,
            mask_i16x16 = i16x16,
            mask_i16x32 = i16x32,

            mask_i32x4 = i32x4,
            mask_i32x8 = i32x8,
            mask_i32x16 = i32x16,

            mask_u8x16 = u8x16,
            mask_u8x32 = u8x32,
            mask_u8x64 = u8x64,

            mask_u32x4 = u32x4,
            mask_u32x8 = u32x8,
            mask_u32x16 = u32x16,

            mask_u64x2 = u64x2,
            mask_u64x4 = u64x4,
        );
    };
}

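// Define the vector associated types as the concrete representation types of the same
// name; the zero-argument form stamps out the full set used by `Architecture`
// implementations.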
macro_rules! typedef {
    () => {
        $crate::arch::typedef!(
            f16x8,
            f16x16,

            f32x4,
            f32x8,
            f32x16,

            i8x16,
            i8x32,
            i8x64,

            i16x8,
            i16x16,
            i16x32,

            i32x4,
            i32x8,
            i32x16,

            u8x16,
            u8x32,
            u8x64,

            u32x4,
            u32x8,
            u32x16,

            u64x2,
            u64x4,
        );
    };
    ($repr:ident) => {
        type $repr = $repr;
    };
    ($($repr:ident),+ $(,)?) => {
        $($crate::arch::typedef!($repr);)+
    };
}

pub(crate) use maskdef;
pub(crate) use typedef;

///////////
// Tests //
///////////

// All tests here **must** run successfully under Miri.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lifetime::{Mut, Ref};

    struct TestOp;

    // Returns a static string.
    impl<A> Target<A, &'static str> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A) -> &'static str {
            "hello world"
        }
    }

    // Simply add all elements in the array and return the result.
    impl<A> Target1<A, f32, &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &[f32]) -> f32 {
            x.iter().sum()
        }
    }

    impl<A> FTarget1<A, f32, &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &[f32]) -> f32 {
            <_ as Target1<_, _, _>>::run(Self, arch, x)
        }
    }

    // Both perform a sum and copy the results into the destination.
    impl<A> Target2<A, f32, &mut [f32], &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &mut [f32], y: &[f32]) -> f32 {
            x.copy_from_slice(y);
            y.iter().sum()
        }
    }

    impl<A> FTarget2<A, f32, &mut [f32], &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &mut [f32], y: &[f32]) -> f32 {
            <_ as Target2<_, _, _, _>>::run(TestOp, arch, x, y)
        }
    }

    impl<A> Target3<A, f32, &mut [f32], &[f32], f32> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
            assert_eq!(x.len(), y.len());
            x.iter_mut().zip(y.iter()).for_each(|(d, s)| *d = *s + z);
            y.iter().sum()
        }
    }

    impl<A> FTarget3<A, f32, &mut [f32], &[f32], f32> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
            <_ as Target3<_, _, _, _, _>>::run(TestOp, arch, x, y, z)
        }
    }

    //------------------//
    // Target Interface //
    //------------------//

    #[test]
    fn zero_arg_target() {
        let expected = "hello world";
        assert_eq!((Scalar).run(TestOp), expected);
        assert_eq!((Scalar).run_inline(TestOp), expected);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }
    }

    #[test]
    fn one_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        assert_eq!((Scalar).run1(TestOp, &src), sum);
        assert_eq!((Scalar).run1_inline(TestOp, &src), sum);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }
    }

    #[test]
    fn two_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        macro_rules! gen_test {
            ($arch:ident) => {{
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run2(TestOp, &mut dst, &src), sum);
                assert_eq!(dst, src);
            }

            {
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run2_inline(TestOp, &mut dst, &src), sum);
                assert_eq!(dst, src);
            }};
        }

        gen_test!(Scalar);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            gen_test!(arch);
        }
    }

    #[test]
    fn three_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();
        let offset = 10.0f32;
        let expected = [11.0f32, 12.0f32, 13.0f32];

        macro_rules! gen_test {
            ($arch:ident) => {{
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run3(TestOp, &mut dst, &src, offset), sum);
                assert_eq!(dst, expected);
            }

            {
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run3_inline(TestOp, &mut dst, &src, offset), sum);
                assert_eq!(dst, expected);
            }};
        }

        gen_test!(Scalar);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            gen_test!(arch);
        }
    }

    //----------------------------//
    // Function Pointer Interface //
    //----------------------------//

    #[test]
    fn one_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        type FnPtr = Dispatched1<f32, Ref<[f32]>>;

        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let f: FnPtr = (Scalar).dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }
    }

    #[test]
    fn two_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        type FnPtr = Dispatched2<f32, Mut<[f32]>, Ref<[f32]>>;

        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = (Scalar).dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }
    }

    #[test]
    fn three_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();
        let offset = 10.0f32;
        let expected = [11.0f32, 12.0f32, 13.0f32];

        type FnPtr = Dispatched3<f32, Mut<[f32]>, Ref<[f32]>, f32>;

        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = (Scalar).dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }
    }
}