// diskann_wide/arch/mod.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Traits and functions supporting multi-architecture applications.
7//!
8//! Many SIMD instructions are micro-architecture specific, meaning that only a subset of
9//! CPUs found in the wild can support SIMD accelerated algorithms. This module provides
10//! tools for writing SIMD algorithms supporting multiple architectures and provides a
11//! light-weight runtime dispatching service to select the most appropriate implementation
12//! at run time.
13//!
14//! The example code below demonstrates a multi-versioned `X = X + Y` kernel:
15//! ```rust
16//! use diskann_wide::arch::{Target2, dispatch2};
17//!
18//! // A zero-sized type that we can use to implement a trait.
19//! struct Add;
20//!
21//! impl<A: diskann_wide::Architecture> Target2<A, (), &mut [f32], &[f32]> for Add {
22//!     #[inline]
23//!     fn run(self, _: A, dst: &mut [f32], src: &[f32]) {
24//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
25//!     }
26//! }
27//!
28//! fn add(dst: &mut [f32], src: &[f32]) {
29//!     dispatch2(Add, dst, src)
30//! }
31//!
32//! let mut dst = vec![1.0, 2.0, 3.0];
33//! add(&mut dst, &[2.0, 3.0, 4.0]);
34//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
35//! ```
36//!
//! Let's break down what's happening.
38//!
//! The function [`dispatch2`] (suffixed with "2" because it takes two arguments, more on
//! this later) takes the struct `Add`, which is required to implement [`Target2`] for all
//! micro-architecture levels supported by `wide` for the compilation target CPU.
42//!
43//! It then will determine at run time the features supported by the current CPU and invoke
44//! `Add::run` with the best architecture. The above example does not use any explicit SIMD
45//! and is generic with respect to the [`Architecture`]. This still allows the compiler to
46//! perform auto-vectorization for different platforms, which can still result in a speed-up.
47//!
48//! ## Mechanics of Dispatching
49//!
50//! Run time architecture detection happens only once (modulo race conditions) and once
51//! resolved involves an atomic load and a branch.
52//!
53//! ## Variadic traits and ABI
54//!
55//! The traits like [`Target1`] or [`dispatch2`] are suffixed by the number of additional
56//! arguments that they take. This is important for cases where we want the calling
57//! convention of the dispatched-to function to match the calling-convention of the
58//! dispatcher function. In this case, Rust will invoke the dispatched-to function using a
59//! jump instead of a function call.
60//!
61//! For example, if we had implemented the add method above without these variadics:
62//! ```rust
63//! use diskann_wide::arch::{Target, dispatch};
64//!
65//! struct AddV2<'a>(&'a mut [f32], &'a [f32]);
66//!
67//! impl<A: diskann_wide::Architecture> Target<A, ()> for AddV2<'_> {
68//!     #[inline]
69//!     fn run(self, _: A) {
70//!         std::iter::zip(self.0.iter_mut(), self.1.iter()).for_each(|(d, s)| *d += *s);
71//!     }
72//! }
73//!
74//! #[inline(never)]
75//! fn add_v2(dst: &mut [f32], src: &[f32]) {
76//!     dispatch(AddV2(dst, src))
77//! }
78//!
79//! let mut dst = vec![1.0, 2.0, 3.0];
80//! add_v2(&mut dst, &[2.0, 3.0, 4.0]);
81//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
82//! ```
83//! then the function still works, but the assembly code goes from something that looks like
84//! this
85//! ```asm
86//!        mov rax, qword ptr [rip + diskann_wide::arch::x86_64::ARCH_NUMBER@GOTPCREL]
87//!        mov rax, qword ptr [rax]
88//!        cmp rax, 1
89//!        je diskann_wide::arch::x86_64::V3::run_with_2
90//!        test rax, rax
91//!        je diskann_wide::arch::x86_64::dispatch_resolve2
92//!        <scalar-code>
93//! ```
94//! where there are no stack writes and the `V3` compatible code is reached via `jmp` to
95//! something that looks like
96//! ```asm
97//!        sub rsp, 56
98//!        mov rax, qword ptr fs:[40]
99//!        mov qword ptr [rsp + 48], rax
100//!        mov rax, qword ptr [rip + diskann_wide::arch::x86_64::ARCH_NUMBER@GOTPCREL]
101//!        mov rax, qword ptr [rax]
102//!        test rax, rax
103//!        je .LBB14_3
104//!        cmp rax, 1
105//!        jne .LBB14_2
106//!        mov qword ptr [rsp + 8], rdi
107//!        mov qword ptr [rsp + 16], 3
108//!        lea rax, [rip + .L__unnamed_8]
109//!        mov qword ptr [rsp + 24], rax
110//!        mov qword ptr [rsp + 32], 3
111//!        lea rax, [rsp + 7]
112//!        mov qword ptr [rsp + 40], rax
113//!        lea rdi, [rsp + 8]
114//!        call diskann_wide::arch::x86_64::V3::run_with
115//!        jmp .LBB14_5
116//!.LBB14_3:
117//!        mov qword ptr [rsp + 8], rdi
118//!        mov qword ptr [rsp + 16], 3
119//!        lea rax, [rip + .L__unnamed_8]
120//!        mov qword ptr [rsp + 24], rax
121//!        mov qword ptr [rsp + 32], 3
122//!        lea rdi, [rsp + 8]
123//!        call diskann_wide::arch::x86_64::dispatch_resolve::<loop_example::AddV2, ()>
124//!        jmp .LBB14_5
125//!        <scalar-code>
126//! ```
127//! Notice some unconditional stack writes, stack preparation for calling the `V3` compatible
128//! code, and a `call` to run the code.
129//!
//! What's happening is that Rust will not inline functions annotated with
//! `#[target_feature(enable = "feature")]` into an incompatible context. Since `AddV2`
//! exceeds 16 bytes, it must be passed on the stack (on Linux at least). Therefore, we
//! have extra overhead of stack preparation to call the dispatch target.
134//!
135//! ## Function Pointer API
136//!
137//! The previous section discussed performing a dynamic dispatch at a single call site, but
138//! you need to pay the (admittedly small) overhead every time this function is called.
139//! Another approach would be to perform dispatch a single time by obtaining a function pointer
140//! to the dispatched function and then calling through that function pointer.
141//!
142//! This is a little tricky for two reasons (and a whole host of less obvious reasons).
143//!
144//! Reason 1: Functions with additional `target_features` cannot be inlined. This means the
145//! simple approach of
146//! ```rust
147//! use diskann_wide::{Architecture, arch::{Scalar, Target, Target2, dispatch2}};
148//!
149//! // A zero-sized type that we can use to implement a trait.
150//! struct Add;
151//!
152//! impl<A: Architecture> Target2<A, (), &mut [f32], &[f32]> for Add {
153//!     #[inline]
154//!     fn run(self, _: A, dst: &mut [f32], src: &[f32]) {
155//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
156//!     }
157//! }
158//!
159//! impl<A: Architecture> Target<A, fn(A, &mut [f32], &[f32])> for Add {
160//!    fn run(self, arch: A) -> fn(A, &mut [f32], &[f32]) {
161//!        // Create a non-capturing closure that invokes `arch.run2`.
162//!        //
163//!        // The invocation of `arch.run2` will apply the necessary target features and
164//!        // the non-capturing closure can be coerced into a function pointer.
165//!        let f = |arch: A, dst: &mut [f32], src: &[f32]| arch.run2(Add, dst, src);
166//!        f
167//!    }
168//! }
169//!
170//! let f: fn(Scalar, &mut [f32], &[f32]) = (Scalar).run(Add);
171//!
172//! let mut dst = vec![1.0, 2.0, 3.0];
173//! f(Scalar, &mut dst, &[2.0, 3.0, 4.0]);
174//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
175//! ```
176//! would likely generate code that looks something like:
177//! ```asm
178//! .section .text.<<diskann_wide::Add as diskann_wide::arch::Target<_, _>::run::{closure#0} /* snip */>>
179//!        .p2align        4
180//!.type   <<diskann_wide::Add as _>::call_once,@function
181//!<<diskann_wide::Add as _>>::call_once:
182//!        .cfi_startproc
183//!        jmp <diskann_wide::arch::Scalar>::run_with_2::<diskann_wide::Add, &mut [f32], &[f32], ()>
184//! ```
185//! The body is simply an unconditional jump to the actual implementation precisely because
186//! the actual implementation cannot be inlined into the body of the closure we coerced into
187//! a function pointer. Unfortunately, the same applies to most other ways one would try to
188//! create a function pointer to the dispatched-to function.
189//!
//! The consequence of this is that we need to take an **unsafe** function pointer so we
//! can dispatch calls directly to the implementation.
192//!
193//! Reason 2: Even if the above approach worked, the [`Architecture`] is still present in the
194//! signature of the `fn`, meaning we haven't really hidden the micro-architecture
195//! information.
196//!
197//! With that in mind, the current solution looks like the following.
198//! ```rust
199//! use diskann_wide::{
200//!     Architecture,
201//!     arch::{self, Dispatched2},
202//!     lifetime::{Ref, Mut},
203//! };
204//!
205//! struct Add;
206//!
207//! // Note the use of `FTarget` instead of `Target`. That is because the implementation is
208//! // simply an associated function instead of a method.
209//! impl<A: Architecture> arch::FTarget2<A, (), &mut [f32], &[f32]> for Add {
210//!     #[inline]
211//!     fn run(_: A, dst: &mut [f32], src: &[f32]) {
212//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
213//!     }
214//! }
215//!
216//! // The `Dispatched2` struct is a slightly magical wrapper around a function pointer,
217//! // returning the unit type `()` and taking two arguments; one `&mut [f32]` and
218//! // the other `&[f32]`.
219//! //
220//! // The need for `Mut` and `Ref` is described below.
221//! type FnPtr = Dispatched2<(), Mut<[f32]>, Ref<[f32]>>;
222//!
223//! impl<A: Architecture> arch::Target<A, FnPtr> for Add {
224//!    fn run(self, arch: A) -> FnPtr {
225//!        arch.dispatch2::<Self, (), Mut<[f32]>, Ref<[f32]>>()
226//!    }
227//! }
228//!
229//! let f: FnPtr = diskann_wide::arch::dispatch(Add);
230//! let mut dst = vec![1.0, 2.0, 3.0];
231//!
232//! // Invoke the function pointer
233//! f.call(&mut dst, &[2.0, 3.0, 4.0]);
234//! assert_eq!(dst, &[3.0, 5.0, 7.0]);
235//! ```
236//! The resulting function pointer (though "safely unsafe"), successfully hides that
237//! dispatched micro-architecture and when called will always go directly to the
238//! implementation rather than needing the trampoline.
239//!
240//! The [`Architecture`] methods
241//! * [`Architecture::dispatch1`], [`Architecture::dispatch2`], [`Architecture::dispatch3`]
242//!
//! will produce function pointers of the annotated arities
244//!
245//! * [`Dispatched1`], [`Dispatched2`], and [`Dispatched3`]
246//!
247//! and are accompanied by the generator traits
248//!
249//! * [`FTarget1`], [`FTarget2`], and [`FTarget3`]
250//!
251//! ### Lifetime Annotations and Limitations
252//!
253//! One thing to note in the above example is the use of [`crate::lifetime::Mut`] and
254//! [`crate::lifetime::Ref`] in the invocation of [`Architecture::dispatch2`]. This is an
255//! unfortunate limitation of the Rust compiler at the moment when it comes to inferring
256//! lifetimes of function pointers.
257//!
258//! For example, the following does not compile
259//! ```compile_fail
260//! pub struct Example;
261//!
262//! trait Run<T, U> {
263//!     fn run(x: T, y: U);
264//! }
265//!
266//! impl Example {
267//!     fn run<F, T, U>(self, x: T, y: U)
268//!     where
269//!         F: Run<T, U>,
270//!     {
271//!         F::run(x, y)
272//!     }
273//! }
274//!
275//! struct Add;
276//!
277//! impl Run<&mut [f32], &[f32]> for Add {
278//!     #[inline]
279//!     fn run(dst: &mut [f32], src: &[f32]) {
280//!         std::iter::zip(dst.iter_mut(), src.iter()).for_each(|(d, s)| *d += *s);
281//!     }
282//! }
283//!
284//! // This fails to compile! :(
285//! pub fn make() -> fn(Example, &mut [f32], &[f32]) {
286//!     let f = Example::run::<Add, &mut [f32], &[f32]>;
287//!     f
288//! }
289//! ```
290//! Fails to compile with the following error message
291//! ```text
292//! error[E0308]: mismatched types
293//!   --> <source>:27:5
294//!    |
295//! 25 | pub fn make() -> fn(Example, &mut [f32], &[f32]) {
296//!    |                  ------------------------------- expected `for<'a, 'b> fn(Example, &'a mut [f32], &'b [f32])` because of return type
297//! 26 |     let f = Example::run::<Add, &mut [f32], &[f32]>;
298//! 27 |     f
299//!    |     ^ one type is more general than the other
300//!    |
301//!    = note: expected fn pointer `for<'a, 'b> fn(Example, &'a mut _, &'b _)`
302//!                  found fn item `fn(Example, &mut _, &_) {Example::run::<Add, &mut [f32], &[f32]>}`
303//! ```
304//! The [`crate::lifetime::AddLifetime`] trait is the only solution the author found at time
305//! of writing that both avoids the trampoline the closure-like approaches induce and the above
306//! lifetime error. This leads to a few practical limitations:
307//!
308//! * Types passed through the function pointer interface can have at most a single lifetime.
309//! * The lifetimes of types passed through the function pointer interface must all be disjoint.
310//! * The return type cannot have a lifetime.
311//!
312//! Practically, these limitations are acceptable because micro-architecture dispatching is
313//! almost always the result of doing mathematical manipulation on primitive types and thus
314//! the lifetimes of the associated types are generally not complicated.
315//!
316//! ## Obtaining the [`Current`] Architecture
317//!
318//! When Rust crates are compiled, they can be provided with a target CPU. The function
319//! [`current`], the type [`Current`], and the constant [`crate::ARCH`] are all populated with
320//! the best matching wide [`Architecture`] selected at compile time.
321//!
322//! ## Hierarchies
323//!
324//! Each [`Architecture`] exposes a [`Level`] via [`Architecture::level()`] that
325//! can be used to compare capabilities without instantiating the architecture.
326//!
327//! ### X86
328//!
329//! * [`x86_64::V4`]: Supporting AVX-512 (and AVX2 and lower).
330//! * [`x86_64::V3`]: Supporting AVX2 and lower.
331//! * [`Scalar`]: Fallback architecture.
332//!
333//! The ordering is `Scalar` < `V3` < `V4`.
334//!
335//! ### Arm
336//!
337//! Currently, Arm support is limited to [`Scalar`].
338
339use half::f16;
340
341use crate::{
342    Const, SIMDCast, SIMDDotProduct, SIMDFloat, SIMDMask, SIMDSelect, SIMDSigned, SIMDSumTree,
343    SIMDUnsigned, SIMDVector, SplitJoin, ZipUnzip, lifetime::AddLifetime,
344};
345
346pub(crate) mod emulated;
347
/// An [`Architecture`] that implements all operations as scalar loops, relying on the
349/// compiler for optimization.
350pub use emulated::Scalar;
351
352/// An opaque representation of an [`Architecture`]'s capability level.
353///
354/// `Level` allows comparing the relative capabilities of different architectures
355/// without requiring an instance of the architecture type. This is useful for
356/// compile-time checks against [`crate::ARCH`] where constructing architecture
357/// types like [`x86_64::V3`] would require `unsafe`.
358///
359/// Levels are totally ordered within an ISA family, with greater values indicating
360/// more capable instruction sets. [`Scalar`] is always the lowest level.
361///
362/// # Examples
363///
364/// Checking if the compile-time architecture meets a minimum capability:
365///
366/// ```
367/// #[cfg(target_arch = "x86_64")]
368/// use diskann_wide::{Architecture, arch};
369///
370/// // Check at compile time whether we were built with AVX2+ support.
371/// #[cfg(target_arch = "x86_64")]
372/// let _meets_v3 = arch::Current::level() >= arch::x86_64::V3::level();
373/// ```
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Level(LevelInner);

impl Level {
    /// The [`Level`] of the [`Scalar`] fallback architecture.
    ///
    /// This is the lowest level on every target; constructors for the higher,
    /// ISA-specific levels are added in the cfg-gated `impl Level` blocks below.
    const fn scalar() -> Self {
        Self(LevelInner::Scalar)
    }
}
382
// Select, at compile time, which architecture module backs this crate's public
// `current`/`Current` items and the `dispatch*` family of functions.
//
// Exactly one branch is active per build target:
// * x86_64  — runtime dispatch over `Scalar`/`V3`/`V4` (see `x86_64` module).
// * aarch64 — runtime dispatch provided by the `aarch64` module.
// * other   — no runtime dispatch; everything resolves statically to `Scalar`.
cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        // Delegate to the architecture selection within the `x86_64` module.
        pub mod x86_64;

        use x86_64::LevelInner;

        pub use x86_64::current;
        pub use x86_64::Current;

        pub use x86_64::dispatch;
        pub use x86_64::dispatch1;
        pub use x86_64::dispatch2;
        pub use x86_64::dispatch3;

        pub use x86_64::dispatch_no_features;
        pub use x86_64::dispatch1_no_features;
        pub use x86_64::dispatch2_no_features;
        pub use x86_64::dispatch3_no_features;

        // Levels for the x86-64 micro-architecture tiers; ordering follows the
        // module docs: `Scalar` < `V3` < `V4`.
        impl Level {
            /// Level for [`x86_64::V3`] (AVX2 and lower).
            const fn v3() -> Self {
                Self(LevelInner::V3)
            }

            /// Level for [`x86_64::V4`] (AVX-512, AVX2 and lower).
            const fn v4() -> Self {
                Self(LevelInner::V4)
            }
        }
    } else if #[cfg(target_arch = "aarch64")] {
        // Delegate to the architecture selection within the `aarch64` module.
        pub mod aarch64;

        use aarch64::LevelInner;

        pub use aarch64::current;
        pub use aarch64::Current;

        pub use aarch64::dispatch;
        pub use aarch64::dispatch1;
        pub use aarch64::dispatch2;
        pub use aarch64::dispatch3;

        pub use aarch64::dispatch_no_features;
        pub use aarch64::dispatch1_no_features;
        pub use aarch64::dispatch2_no_features;
        pub use aarch64::dispatch3_no_features;

        impl Level {
            /// Level for the NEON architecture tier.
            const fn neon() -> Self {
                Self(LevelInner::Neon)
            }
        }
    } else {
        // Fallback: the only architecture is `Scalar`, so every dispatch
        // function forwards directly with no feature detection at all.
        pub type Current = Scalar;

        // There is only one architecture present in this mode.
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
        enum LevelInner {
            Scalar,
        }

        pub const fn current() -> Current {
            Scalar::new()
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch<T, R>(f: T) -> R
        where T: Target<Scalar, R> {
            f.run(Scalar::new())
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch1<T, T0, R>(f: T, x0: T0) -> R
        where T: Target1<Scalar, R, T0> {
            f.run(Scalar::new(), x0)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch2<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
        where T: Target2<Scalar, R, T0, T1> {
            f.run(Scalar::new(), x0, x1)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch3<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
        where T: Target3<Scalar, R, T0, T1, T2> {
            f.run(Scalar::new(), x0, x1, x2)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch_no_features<T, R>(f: T) -> R
        where T: Target<Scalar, R> {
            f.run(Scalar::new())
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch1_no_features<T, T0, R>(f: T, x0: T0) -> R
        where T: Target1<Scalar, R, T0> {
            f.run(Scalar::new(), x0)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch2_no_features<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
        where T: Target2<Scalar, R, T0, T1> {
            f.run(Scalar::new(), x0, x1)
        }

        /// Run the target functor.
        ///
        /// In scalar mode, this does nothing special.
        pub fn dispatch3_no_features<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
        where T: Target3<Scalar, R, T0, T1, T2> {
            f.run(Scalar::new(), x0, x1, x2)
        }
    }
}
514
mod sealed {
    /// Private super-trait that prevents downstream crates from implementing
    /// [`super::Architecture`], while also carrying the auto-trait and utility
    /// bounds every architecture type must satisfy.
    pub trait Sealed: std::fmt::Debug + Copy + PartialEq + Send + Sync + 'static {}
}
518
519pub(crate) use sealed::Sealed;
520
/// Declare an associated SIMD vector type on [`Architecture`].
///
/// `vector!(name: <Arch, Scalar, N, mask_name> + Extra)` expands to an associated
/// type bound by [`SIMDVector`] with the given architecture, scalar type, constant
/// lane count, and mask type, plus any `Extra` trait bounds after the `+`.
macro_rules! vector {
    ($me:ident: <$self:ident, $T:ty, $N:literal, $mask:ident> + $($rest:tt)*) => {
        type $me: SIMDVector<Arch = $self, Scalar = $T, ConstLanes = Const<$N>, Mask = Self::$mask> + $($rest)*;
    }
}
526
// Associated-type names intentionally mirror primitive vector spellings
// (`f32x8`, `mask_i8x32`, ...), hence the lint allowance.
#[allow(non_camel_case_types)]
pub trait Architecture: sealed::Sealed {
    // mask types
    //
    // Each `mask_*` type is the mask paired with the like-named vector type
    // declared further down. Masks that can blend lanes of a vector carry an
    // additional `SIMDSelect` bound for that vector; a `From` bound marks a
    // mask that can be converted from another mask of the same lane count.
    type mask_f16x8: SIMDMask;
    type mask_f16x16: SIMDMask;

    type mask_f32x4: SIMDMask + SIMDSelect<Self::f32x4>;
    type mask_f32x8: SIMDMask + SIMDSelect<Self::f32x8>;
    type mask_f32x16: SIMDMask + SIMDSelect<Self::f32x16>;

    type mask_i8x16: SIMDMask;
    type mask_i8x32: SIMDMask;
    type mask_i8x64: SIMDMask;

    type mask_i16x8: SIMDMask;
    type mask_i16x16: SIMDMask;
    type mask_i16x32: SIMDMask;

    type mask_i32x4: SIMDMask;
    type mask_i32x8: SIMDMask + From<Self::mask_f32x8> + SIMDSelect<Self::i32x8>;
    type mask_i32x16: SIMDMask + SIMDSelect<Self::i32x16>;

    type mask_u8x16: SIMDMask;
    type mask_u8x32: SIMDMask;
    type mask_u8x64: SIMDMask;

    type mask_u32x4: SIMDMask;
    type mask_u32x8: SIMDMask + From<Self::mask_f32x8>;
    type mask_u32x16: SIMDMask + SIMDSelect<Self::u32x16>;
    type mask_u64x2: SIMDMask;
    type mask_u64x4: SIMDMask;

    /////////////////
    //-- vectors --//
    /////////////////
    //
    // Vector associated types, declared via the `vector!` macro above. Extra
    // bounds after the `+` describe per-type capabilities: widening/narrowing
    // casts (`SIMDCast`, `From`), half-width splitting (`SplitJoin`),
    // interleaving (`ZipUnzip`), reductions (`SIMDSumTree`), and dot products
    // (`SIMDDotProduct`).

    // floats
    vector!(
        f16x8: <Self, f16, 8, mask_f16x8>
        + SIMDCast<f32, Cast = Self::f32x8>
    );
    vector!(
        f16x16: <Self, f16, 16, mask_f16x16>
        + SplitJoin<Halved = Self::f16x8>
        + ZipUnzip<Halved = Self::f16x8>
        + SIMDCast<f32, Cast = Self::f32x16>
    );

    vector!(
        f32x4: <Self, f32, 4, mask_f32x4>
        + SIMDFloat
        + SIMDSumTree
    );
    vector!(
        f32x8: <Self, f32, 8, mask_f32x8>
        + SIMDFloat
        + SIMDSumTree
        + SIMDCast<f16, Cast = Self::f16x8>
        + SplitJoin<Halved = Self::f32x4>
        + From<Self::f16x8>
    );
    vector!(
        f32x16: <Self, f32, 16, mask_f32x16>
        + SIMDFloat
        + SplitJoin<Halved = Self::f32x8>
        + SIMDSumTree
        + From<Self::f16x16>
    );

    // signed-integer
    vector!(
        i8x16: <Self, i8, 16, mask_i8x16>
        + SIMDSigned
    );
    vector!(
        i8x32: <Self, i8, 32, mask_i8x32>
        + SIMDSigned
        + SplitJoin<Halved = Self::i8x16>
        + ZipUnzip<Halved = Self::i8x16>
    );
    vector!(
        i8x64: <Self, i8, 64, mask_i8x64>
        + SIMDSigned
    );

    vector!(
        i16x8: <Self, i16, 8, mask_i16x8>
        + SIMDSigned
    );
    vector!(
        i16x16: <Self, i16, 16, mask_i16x16>
        + SIMDSigned
        + SplitJoin<Halved = Self::i16x8>
        + ZipUnzip<Halved = Self::i16x8>
        + From<Self::i8x16>
        + From<Self::u8x16>
    );
    vector!(
        i16x32: <Self, i16, 32, mask_i16x32>
        + SIMDSigned
        + SplitJoin<Halved = Self::i16x16>
        + From<Self::i8x32>
        + From<Self::u8x32>
    );

    vector!(
        i32x4: <Self, i32, 4, mask_i32x4>
        + SIMDSigned
    );
    vector!(
        i32x8: <Self, i32, 8, mask_i32x8>
        + SIMDSigned
        + SIMDSumTree
        + SplitJoin<Halved = Self::i32x4>
        + ZipUnzip<Halved = Self::i32x4>
        + SIMDDotProduct<Self::i16x16>
        + SIMDDotProduct<Self::u8x32, Self::i8x32>
        + SIMDDotProduct<Self::i8x32, Self::u8x32>
        + SIMDCast<f32, Cast = Self::f32x8>
    );
    vector!(
        i32x16: <Self, i32, 16, mask_i32x16>
        + SIMDSigned
        + SIMDSumTree
        + SplitJoin<Halved = Self::i32x8>
        + SIMDDotProduct<Self::u8x64, Self::i8x64>
        + SIMDDotProduct<Self::i8x64, Self::u8x64>
    );

    // unsigned-integer
    vector!(
        u8x16: <Self, u8, 16, mask_u8x16>
        + SIMDUnsigned
    );
    vector!(
        u8x32: <Self, u8, 32, mask_u8x32>
        + SIMDUnsigned
        + SplitJoin<Halved = Self::u8x16>
        + ZipUnzip<Halved = Self::u8x16>
    );

    vector!(
        u8x64: <Self, u8, 64, mask_u8x64>
        + SIMDUnsigned
    );

    vector!(
        u32x4: <Self, u32, 4, mask_u32x4>
        + SIMDUnsigned
    );
    vector!(
        u32x8: <Self, u32, 8, mask_u32x8>
        + SplitJoin<Halved = Self::u32x4>
        + ZipUnzip<Halved = Self::u32x4>
        + SIMDUnsigned
        + SIMDSumTree
    );
    vector!(
        u32x16: <Self, u32, 16, mask_u32x16>
        + SIMDUnsigned
        + SIMDSumTree
        + SplitJoin<Halved = Self::u32x8>
    );

    vector!(
        u64x2: <Self, u64, 2, mask_u64x2>
        + SIMDUnsigned
    );
    vector!(
        u64x4: <Self, u64, 4, mask_u64x4>
        + SplitJoin<Halved = Self::u64x2>
        + SIMDUnsigned
    );

    //---------//
    // Methods //
    //---------//

    /// Return an opaque [`Level`] representing the capabilities of this architecture.
    ///
    /// Levels that compare greater represent architectures that are more capable.
    ///
    /// # Examples
    ///
    /// ```
    /// use diskann_wide::{Architecture, arch};
    ///
    /// // Scalar is the baseline — every other architecture compares greater.
    /// assert_eq!(arch::Scalar::level(), arch::Scalar::level());
    ///
    /// #[cfg(target_arch = "x86_64")]
    /// assert!(arch::Scalar::level() < arch::x86_64::V3::level());
    /// ```
    fn level() -> Level;

    /// Run the provided closure targeting this architecture.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run<F, R>(self, f: F) -> R
    where
        F: Target<Self, R>;

    /// Run the provided closure targeting this architecture with an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run_inline<F, R>(self, f: F) -> R
    where
        F: Target<Self, R>;

    /// Run the provided closure targeting this architecture with an additional argument.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run1<F, T0, R>(self, f: F, x0: T0) -> R
    where
        F: Target1<Self, R, T0>;

    /// Run the provided closure targeting this architecture with an additional argument and
    /// an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target1` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run1_inline<F, T0, R>(self, f: F, x0: T0) -> R
    where
        F: Target1<Self, R, T0>;

    /// Run the provided closure targeting this architecture with two additional arguments.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run2<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
    where
        F: Target2<Self, R, T0, T1>;

    /// Run the provided closure targeting this architecture with two additional arguments
    /// and an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target2` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run2_inline<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
    where
        F: Target2<Self, R, T0, T1>;

    /// Run the provided closure targeting this architecture with three additional arguments.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    fn run3<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
    where
        F: Target3<Self, R, T0, T1, T2>;

    /// Run the provided closure targeting this architecture with three additional arguments
    /// and an inlining hint.
    ///
    /// This function is always safe to call, but the function `f` likely needs to be
    /// inlined into `run` for the correct target features to be applied.
    ///
    /// Note that although an inline hint is applied, it is not guaranteed that this call
    /// will be inlined due to the interaction of `target_features`. If you really need `F`
    /// to be inlined, you can call its `Target3` method directly, but care must be taken
    /// because this will not reapply `target_features`.
    fn run3_inline<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
    where
        F: Target3<Self, R, T0, T1, T2>;

    /// Return a function pointer invoking [`FTarget1::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{Architecture, arch::FTarget1};
    ///
    /// struct Square;
    ///
    /// impl<A: Architecture> FTarget1<A, f32, f32> for Square
    /// {
    ///     fn run(_: A, x: f32) -> f32 {
    ///         x * x
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch1::<Square, f32, f32>();
    /// assert_eq!(f.call(10.0), 100.0);
    /// ```
    fn dispatch1<F, R, T0>(self) -> Dispatched1<R, T0>
    where
        T0: AddLifetime,
        F: for<'a> FTarget1<Self, R, T0::Of<'a>>;

    /// Return a function pointer invoking [`FTarget2::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{
    ///     Architecture,
    ///     arch::FTarget2,
    ///     lifetime::{Mut, Ref},
    /// };
    ///
    /// struct Copy;
    ///
    /// // Copy a slice and return the number of elements.
    /// impl<A: Architecture> FTarget2<A, usize, &mut [f32], &[f32]> for Copy
    /// {
    ///     fn run(_: A, dst: &mut [f32], src: &[f32]) -> usize {
    ///         dst.copy_from_slice(src);
    ///         src.len()
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch2::<Copy, usize, Mut<[f32]>, Ref<[f32]>>();
    /// let src = [1.0, 2.0, 3.0];
    /// let mut dst = [0.0f32; 3];
    /// assert_eq!(f.call(&mut dst, &src), 3);
    /// assert_eq!(src, dst);
    /// ```
    fn dispatch2<F, R, T0, T1>(self) -> Dispatched2<R, T0, T1>
    where
        T0: AddLifetime,
        T1: AddLifetime,
        F: for<'a, 'b> FTarget2<Self, R, T0::Of<'a>, T1::Of<'b>>;

    /// Return a function pointer invoking [`FTarget3::run`] with `self` as the architecture.
    /// ```
    /// use diskann_wide::{
    ///     Architecture,
    ///     arch::FTarget3,
    ///     lifetime::Ref,
    /// };
    ///
    /// struct Sum;
    ///
    /// // Return the sum of the three arguments.
    /// impl<A: Architecture> FTarget3<A, f32, &f32, &f32, &f32> for Sum
    /// {
    ///     fn run(_: A, x: &f32, y: &f32, z: &f32) -> f32 {
    ///         x + y + z
    ///     }
    /// }
    ///
    /// let f = (diskann_wide::ARCH).dispatch3::<Sum, f32, Ref<f32>, Ref<f32>, Ref<f32>>();
    /// assert_eq!(f.call(&1.0, &2.0, &3.0), 6.0);
    /// ```
    fn dispatch3<F, R, T0, T1, T2>(self) -> Dispatched3<R, T0, T1, T2>
    where
        T0: AddLifetime,
        T1: AddLifetime,
        T2: AddLifetime,
        F: for<'a, 'b, 'c> FTarget3<Self, R, T0::Of<'a>, T1::Of<'b>, T2::Of<'c>>;
}
889
/// A functor that targets a particular architecture, accepting no additional arguments.
///
/// A blanket implementation exists for any `FnOnce() -> R` closure, so both
/// zero-sized marker structs and plain closures can be dispatched through the
/// `run`/`run_inline` methods of an [`Architecture`].
pub trait Target<A, R>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture`.
    ///
    /// Consumes `self`, allowing captured state to be moved into the call.
    fn run(self, arch: A) -> R;
}
898
/// A functor that targets a particular architecture, accepting one additional argument.
///
/// A blanket implementation exists for any `FnOnce(T0) -> R` closure.
pub trait Target1<A, R, T0>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the argument `x0`.
    fn run(self, arch: A, x0: T0) -> R;
}
907
/// A functor that targets a particular architecture, accepting two additional arguments.
///
/// A blanket implementation exists for any `FnOnce(T0, T1) -> R` closure.
pub trait Target2<A, R, T0, T1>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the arguments `x0`, `x1`.
    fn run(self, arch: A, x0: T0, x1: T1) -> R;
}
916
/// A functor that targets a particular architecture, accepting three additional arguments.
///
/// A blanket implementation exists for any `FnOnce(T0, T1, T2) -> R` closure.
pub trait Target3<A, R, T0, T1, T2>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the arguments
    /// `x0`, `x1`, `x2`.
    fn run(self, arch: A, x0: T0, x1: T1, x2: T2) -> R;
}
925
/// A variation of [`Target1`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
///
/// Because `run` is an associated function (it takes no `self`), it can be
/// captured as a plain function pointer by `Architecture::dispatch1`.
pub trait FTarget1<A, R, T0>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the argument `x0`.
    fn run(arch: A, x0: T0) -> R;
}
935
/// A variation of [`Target2`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
///
/// Because `run` is an associated function (it takes no `self`), it can be
/// captured as a plain function pointer by `Architecture::dispatch2`.
pub trait FTarget2<A, R, T0, T1>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the arguments `x0`, `x1`.
    fn run(arch: A, x0: T0, x1: T1) -> R;
}
945
/// A variation of [`Target3`] that uses an associated function instead of a method.
///
/// This is used in the function pointer API.
///
/// Because `run` is an associated function (it takes no `self`), it can be
/// captured as a plain function pointer by `Architecture::dispatch3`.
pub trait FTarget3<A, R, T0, T1, T2>
where
    A: Architecture,
{
    /// Run the operation with the provided `Architecture` and the arguments
    /// `x0`, `x1`, `x2`.
    fn run(arch: A, x0: T0, x1: T1, x2: T2) -> R;
}
955
956/// Run the closure with code-generated for the specified architecture.
957///
958/// Note that if the body of the closure is not inlined, this will likely have no effect.
959impl<A, R, F> Target<A, R> for F
960where
961    A: Architecture,
962    F: FnOnce() -> R,
963{
964    #[inline]
965    fn run(self, _: A) -> R {
966        (self)()
967    }
968}
969
970/// Run the closure with code-generated for the specified architecture.
971///
972/// Note that if the body of the closure is not inlined, this will likely have no effect.
973impl<A, R, T0, F> Target1<A, R, T0> for F
974where
975    A: Architecture,
976    F: FnOnce(T0) -> R,
977{
978    #[inline]
979    fn run(self, _: A, x0: T0) -> R {
980        (self)(x0)
981    }
982}
983
984/// Run the closure with code-generated for the specified architecture.
985///
986/// Note that if the body of the closure is not inlined, this will likely have no effect.
987impl<A, R, T0, T1, F> Target2<A, R, T0, T1> for F
988where
989    A: Architecture,
990    F: FnOnce(T0, T1) -> R,
991{
992    #[inline]
993    fn run(self, _: A, x0: T0, x1: T1) -> R {
994        (self)(x0, x1)
995    }
996}
997
998/// Run the closure with code-generated for the specified architecture.
999///
1000/// Note that if the body of the closure is not inlined, this will likely have no effect.
1001impl<A, R, T0, T1, T2, F> Target3<A, R, T0, T1, T2> for F
1002where
1003    A: Architecture,
1004    F: FnOnce(T0, T1, T2) -> R,
1005{
1006    #[inline]
1007    fn run(self, _: A, x0: T0, x1: T1, x2: T2) -> R {
1008        (self)(x0, x1, x2)
1009    }
1010}
1011
/// A hidden architecture for use in the function pointer API.
///
/// `Hidden` stands in for a concrete zero-sized architecture type after the
/// architecture has been erased by transmuting the function pointer (see the
/// `hide!` macro in this module). The `_ASSERT_*` constants below enforce the
/// zero-size / align-1 layout that this transmute relies on.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct Hidden;
1015
// Compile-time proof that `Hidden` occupies no storage. The function-pointer
// transmute in `hide!` requires `Hidden` to be ABI compatible with the
// zero-sized architecture types it replaces.
const _ASSERT_ZST: () = assert!(
    std::mem::size_of::<Hidden>() == 0,
    "Hidden **must** be zero sized"
);
1020
// Compile-time proof that `Hidden` has alignment 1, the second layout
// requirement for the ABI-compatibility argument used by `hide!`.
const _ASSERT_ALIGNED: () = assert!(
    std::mem::align_of::<Hidden>() == 1,
    "Hidden **must** be alignment 1"
);
1025
// Generate a `Dispatched*` wrapper struct around a type-erased,
// architecture-specialized function pointer.
//
// Macro arguments:
// * `$name` - the struct to generate (e.g. `Dispatched1`).
// * `$Ts`   - one type parameter per function argument.
// * `$xs`   - one binding name per function argument.
// * `$lt`   - one higher-ranked lifetime per function argument.
macro_rules! dispatched {
    ($name:ident, { $($Ts:ident )* }, { $($xs:ident )* }, { $($lt:lifetime )* }) => {
        /// A function pointer that calls directly into a micro-architecture optimized
        /// function, returning a value of type `R` and accepting the specified number of
        /// arguments.
        ///
        /// Arguments are mapped using the [`AddLifetime`] trait to enable passing structs
        /// with up to a single non-`'static` lifetime parameter.
        ///
        /// This type is guaranteed:
        ///
        /// * To have the same size as a regular function pointer.
        /// * To have the same ABI as a regular function pointer.
        /// * To use the null-pointer optimization - so `Option<Self>` has the same size as
        ///   `Self`.
        #[derive(Debug)]
        #[repr(transparent)]
        pub struct $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            // The erased pointer: the architecture argument has been replaced
            // by the zero-sized `Hidden` type (see `hide!`).
            f: for<$($lt,)*> unsafe fn(Hidden, $($Ts::Of<$lt>,)*) -> R,
        }

        impl<R, $($Ts,)*> $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            /// Construct a new safe instance of `Self` around the raw function pointer.
            ///
            /// # Safety
            ///
            /// The caller asserts that the runtime implementation of `f` is safe to call
            /// on the current CPU target.
            ///
            /// Usually, this means that the runtime CPU has the target features required
            /// for the destination of `f`.
            unsafe fn new(f: unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R) -> Self {
                Self { f }
            }

            /// Invoke the function with the lifetime-annotated arguments and return the
            /// result.
            ///
            /// The same calling convention applies to all variadic instances; see the
            /// `dispatch1`/`dispatch2`/`dispatch3` examples on [`Architecture`] for usage.
            #[inline(always)]
            pub fn call(self, $($xs: $Ts::Of<'_>,)*) -> R {
                // SAFETY: The constructor of `Dispatched` asserts the call is safe.
                unsafe { (self.f)(Hidden, $($xs,)*) }
            }
        }

        // `Clone`/`Copy` are implemented manually: deriving them would add
        // `$Ts: Clone`/`Copy` bounds, but only the function pointer is stored,
        // which is always copyable.
        impl<R, $($Ts,)*> Clone for $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            fn clone(&self) -> Self {
                *self
            }
        }

        impl<R, $($Ts,)*> Copy for $name<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
        }
    }
}
1095
// Stamp out the function-pointer wrappers for one, two, and three arguments.
dispatched!(Dispatched1, { T0 }, { x0 }, { 'a0 });
dispatched!(Dispatched2, { T0 T1 }, { x0 x1 }, { 'a0 'a1 });
dispatched!(Dispatched3, { T0 T1 T2 }, { x0 x1 x2 }, { 'a0 'a1 'a2 });
1099
/// This macro stamps out the function-pointer transmute trick we use to type-erase
/// architecture in the function-pointer API.
///
/// Macro arguments:
/// * `$name` - the generated helper function (e.g. `hide1`).
/// * `$dispatched` - the wrapper struct it produces (e.g. [`Dispatched1`]).
/// * `$Ts` - one type parameter per function argument.
macro_rules! hide {
    ($name:ident, $dispatched:ident, { $($Ts:ident )* }) => {
        /// Construct a new instance of [`Self`] from the raw function pointer.
        ///
        /// # Safety
        ///
        /// Internally, we will transmute the zero-sized type `A` to another hidden type
        /// to erase the architecture information of the dispatched function pointer.
        ///
        /// This function will not compile if `A` is not zero-sized or does not have an
        /// alignment of 1.
        ///
        /// We can do this because Rust guarantees that zero sized types are ABI
        /// compatible.
        ///
        /// The caller must ensure that winking into existence an instance of `A` is
        /// a safe operation. For [`Architecture`]s, this means that the requirements
        /// of `A::new()` are upheld.
        ///
        /// Put plainly:
        ///
        /// 1. The runtime CPU must support the target features required by `A`.
        /// 2. `A` **must** be a zero-sized type with an alignment of 1.
        unsafe fn $name<A, R, $($Ts,)*>(
            f: unsafe fn(A, $($Ts::Of<'_>,)*) -> R
        ) -> $dispatched<R, $($Ts,)*>
        where
            $($Ts: AddLifetime,)*
        {
            // Check that `A` is a zero sized type.
            const {
                assert!(
                    std::mem::size_of::<A>() == 0,
                    "A must be zero sized to be ABI compatible with `Hidden`"
                )
            };

            // Check that `A` has alignment 1.
            const {
                assert!(
                    std::mem::align_of::<A>() == 1,
                    "A must have an alignment of 1 to be ABI compatible with `Hidden`"
                )
            };

            // SAFETY: The transmute is safe because `Hidden` and `A` are both zero
            // sized types with alignment 1, which are ABI compatible.
            //
            // All the rest of the arguments are untouched.
            //
            // The caller asserts that it is safe to call this function pointer.
            let f = unsafe {
                std::mem::transmute::<
                    unsafe fn(A, $($Ts::Of<'_>,)*) -> R,
                    unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R
                >(f)
            };

            // SAFETY: The caller asserts that it is safe to call `f`.
            unsafe { $dispatched::new(f) }
        }
    }
}
1165
// Stamp out the architecture-erasing helpers for one, two, and three arguments.
hide!(hide1, Dispatched1, { T0 });
hide!(hide2, Dispatched2, { T0 T1 });
hide!(hide3, Dispatched3, { T0 T1 T2 });
1169
1170// Macros to help implement architectures.
1171
// Define the `mask_*` associated types of an architecture implementation.
//
// Invoked with no arguments it expands to the full standard set of mask
// definitions; otherwise it defines `type $mask = <$repr as SIMDVector>::Mask`
// for each `name = repr` pair supplied.
macro_rules! maskdef {
    // Single definition: the mask type is taken from the vector's `SIMDVector` impl.
    ($mask:ident = $repr:ty) => {
        type $mask = <$repr as SIMDVector>::Mask;
    };
    // Comma-separated list: expand each `name = repr` pair individually.
    ($($mask:ident = $repr:ty),+ $(,)?) => {
        $($crate::arch::maskdef!($mask = $repr);)+
    };
    // No arguments: the complete standard set of mask type definitions.
    () => {
        $crate::arch::maskdef!(
            mask_f16x8 = f16x8,
            mask_f16x16 = f16x16,

            mask_f32x4 = f32x4,
            mask_f32x8 = f32x8,
            mask_f32x16 = f32x16,

            mask_i8x16 = i8x16,
            mask_i8x32 = i8x32,
            mask_i8x64 = i8x64,

            mask_i16x8 = i16x8,
            mask_i16x16 = i16x16,
            mask_i16x32 = i16x32,

            mask_i32x4 = i32x4,
            mask_i32x8 = i32x8,
            mask_i32x16 = i32x16,

            mask_u8x16 = u8x16,
            mask_u8x32 = u8x32,
            mask_u8x64 = u8x64,

            mask_u32x4 = u32x4,
            mask_u32x8 = u32x8,
            mask_u32x16 = u32x16,

            mask_u64x2 = u64x2,
            mask_u64x4 = u64x4,
        );
    };
}
1213
// Define the SIMD-vector associated types of an architecture implementation.
//
// Invoked with no arguments it expands to the full standard set; otherwise
// each identifier becomes `type name = name;`, where the right-hand side
// resolves to the concrete vector type in scope at the expansion site.
macro_rules! typedef {
    // No arguments: the complete standard set of vector type definitions.
    () => {
        $crate::arch::typedef!(
            f16x8,
            f16x16,

            f32x4,
            f32x8,
            f32x16,

            i8x16,
            i8x32,
            i8x64,

            i16x8,
            i16x16,
            i16x32,

            i32x4,
            i32x8,
            i32x16,

            u8x16,
            u8x32,
            u8x64,

            u32x4,
            u32x8,
            u32x16,

            u64x2,
            u64x4,
        );
    };
    // Single definition: the associated type shares the name of the concrete
    // vector type it aliases.
    ($repr:ident) => {
        type $repr = $repr;
    };
    // Comma-separated list: expand each identifier individually.
    ($($repr:ident),+ $(,)?) => {
        $($crate::arch::typedef!($repr);)+
    };
}
1255
1256pub(crate) use maskdef;
1257pub(crate) use typedef;
1258
1259///////////
1260// Tests //
1261///////////
1262
1263// All tests here **must** run successfully under Miri.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lifetime::{Mut, Ref};

    /// A single zero-sized functor used to exercise every `Target*` and
    /// `FTarget*` arity through distinct trait implementations.
    struct TestOp;

    // Returns a static string.
    impl<A> Target<A, &'static str> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A) -> &'static str {
            "hello world"
        }
    }

    // Simply add all elements in the array and return the result.
    impl<A> Target1<A, f32, &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &[f32]) -> f32 {
            x.iter().sum()
        }
    }

    // Function-pointer variant: forwards to the `Target1` implementation above.
    impl<A> FTarget1<A, f32, &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &[f32]) -> f32 {
            <_ as Target1<_, _, _>>::run(Self, arch, x)
        }
    }

    // Both perform a sum and copy the results into the destination.
    impl<A> Target2<A, f32, &mut [f32], &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &mut [f32], y: &[f32]) -> f32 {
            x.copy_from_slice(y);
            y.iter().sum()
        }
    }

    // Function-pointer variant: forwards to the `Target2` implementation above.
    impl<A> FTarget2<A, f32, &mut [f32], &[f32]> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &mut [f32], y: &[f32]) -> f32 {
            <_ as Target2<_, _, _, _>>::run(TestOp, arch, x, y)
        }
    }

    // Writes `y[i] + z` into `x[i]` and returns the sum of `y`.
    impl<A> Target3<A, f32, &mut [f32], &[f32], f32> for TestOp
    where
        A: Architecture,
    {
        fn run(self, _: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
            assert_eq!(x.len(), y.len());
            x.iter_mut().zip(y.iter()).for_each(|(d, s)| *d = *s + z);
            y.iter().sum()
        }
    }

    // Function-pointer variant: forwards to the `Target3` implementation above.
    impl<A> FTarget3<A, f32, &mut [f32], &[f32], f32> for TestOp
    where
        A: Architecture,
    {
        fn run(arch: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
            <_ as Target3<_, _, _, _, _>>::run(TestOp, arch, x, y, z)
        }
    }

    //------------------//
    // Target Interface //
    //------------------//

    #[test]
    fn zero_arg_target() {
        let expected = "hello world";
        assert_eq!((Scalar).run(TestOp), expected);
        assert_eq!((Scalar).run_inline(TestOp), expected);

        // The arch-specific cases run only when `new_checked*` returns `Some`,
        // i.e. when the runtime CPU supports the corresponding feature level.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            assert_eq!(arch.run(TestOp), expected);
            assert_eq!(arch.run_inline(TestOp), expected);
        }
    }

    #[test]
    fn one_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        assert_eq!((Scalar).run1(TestOp, &src), sum);
        assert_eq!((Scalar).run1_inline(TestOp, &src), sum);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            assert_eq!(arch.run1(TestOp, &src), sum);
            assert_eq!(arch.run1_inline(TestOp, &src), sum);
        }
    }

    #[test]
    fn two_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        // Exercise both `run2` and `run2_inline` with a fresh destination each.
        macro_rules! gen_test {
            ($arch:ident) => {{
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run2(TestOp, &mut dst, &src), sum);
                assert_eq!(dst, src);
            }

            {
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run2_inline(TestOp, &mut dst, &src), sum);
                assert_eq!(dst, src);
            }};
        }

        gen_test!(Scalar);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            gen_test!(arch);
        }
    }

    #[test]
    fn three_arg_target() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();
        let offset = 10.0f32;
        let expected = [11.0f32, 12.0f32, 13.0f32];

        // Exercise both `run3` and `run3_inline` with a fresh destination each.
        macro_rules! gen_test {
            ($arch:ident) => {{
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run3(TestOp, &mut dst, &src, offset), sum);
                assert_eq!(dst, expected);
            }

            {
                let mut dst = [0.0f32; 3];
                assert_eq!($arch.run3_inline(TestOp, &mut dst, &src, offset), sum);
                assert_eq!(dst, expected);
            }};
        }

        gen_test!(Scalar);

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            gen_test!(arch);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            gen_test!(arch);
        }
    }

    //----------------------------//
    // Function Pointer Interface //
    //----------------------------//

    #[test]
    fn one_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        type FnPtr = Dispatched1<f32, Ref<[f32]>>;

        // Verify the documented layout guarantees: function-pointer size and
        // the null-pointer optimization for `Option<FnPtr>`.
        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let f: FnPtr = (Scalar).dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
            assert_eq!(f.call(&src), sum);
        }
    }

    #[test]
    fn two_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();

        type FnPtr = Dispatched2<f32, Mut<[f32]>, Ref<[f32]>>;

        // Verify the documented layout guarantees: function-pointer size and
        // the null-pointer optimization for `Option<FnPtr>`.
        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = (Scalar).dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
            assert_eq!(f.call(&mut dst, &src), sum);
            assert_eq!(dst, src);
        }
    }

    #[test]
    fn three_arg_function_pointer() {
        let src = [1.0f32, 2.0f32, 3.0f32];
        let sum: f32 = src.iter().sum();
        let offset = 10.0f32;
        let expected = [11.0f32, 12.0f32, 13.0f32];

        type FnPtr = Dispatched3<f32, Mut<[f32]>, Ref<[f32]>, f32>;

        // Verify the documented layout guarantees: function-pointer size and
        // the null-pointer optimization for `Option<FnPtr>`.
        assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
        assert_eq!(
            std::mem::size_of::<Option<FnPtr>>(),
            std::mem::size_of::<fn()>()
        );

        {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = (Scalar).dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V3::new_checked_uncached() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = x86_64::V4::new_checked_miri() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = aarch64::Neon::new_checked() {
            let mut dst = [0.0f32; 3];
            let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
            assert_eq!(f.call(&mut dst, &src, offset), sum);
            assert_eq!(dst, expected);
        }
    }
}