// irithyll_core/simd/ops.rs

//! Core SIMD-accelerated operations: dot product, matrix-vector multiply,
//! and activation functions.
//!
//! These are the hottest primitives across SSM, ESN, and attention/neural
//! forward passes. AVX2 processes 4 `f64` values per instruction, giving up to
//! ~4x throughput on aligned inner loops.
//!
//! # Architecture
//!
//! ```text
//! Public API (safe)            Internal dispatch
//! ─────────────────            ─────────────────
//! simd_dot(a, b)        ──►   avx2::dot_avx2      (x86_64 + AVX2 detected)
//!                       └──►  dot_scalar           (fallback)
//!
//! simd_mat_vec(w,x,..)  ──►   avx2::mat_vec_avx2  (x86_64 + AVX2 detected)
//!                       └──►  mat_vec_scalar       (fallback)
//!
//! simd_tanh(in, out)    ──►   avx2::tanh_avx2     (x86_64 + AVX2, Padé [2,2])
//!                       └──►  tanh_scalar          (fallback)
//!
//! simd_exp(in, out)     ──►   avx2::exp_avx2      (x86_64 + AVX2, range-reduced deg-5)
//!                       └──►  exp_scalar           (fallback)
//!
//! simd_sigmoid(in, out) ──►   avx2::sigmoid_avx2  (x86_64 + AVX2, via exp)
//!                       └──►  sigmoid_scalar       (fallback)
//!
//! simd_silu(in, out)    ──►   avx2::silu_avx2     (x86_64 + AVX2, via sigmoid)
//!                       └──►  silu_scalar          (fallback)
//! ```

32// Runtime detection macro — only available with simd-avx2 feature (implies std on x86_64).
33#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
34use std::is_x86_feature_detected;
35
36// ---------------------------------------------------------------------------
37// Scalar fallbacks (always available, no_std compatible)
38// ---------------------------------------------------------------------------
39
/// Scalar dot product of two slices.
///
/// Computes `sum(a[i] * b[i])` for `i` in `0..min(a.len(), b.len())`;
/// extra elements of the longer slice are ignored. Returns `0.0` for
/// empty input.
#[inline]
fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, which both enforces the
    // min-length contract and lets the compiler elide per-iteration
    // bounds checks (this is the hot non-AVX2 path).
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}
52
/// Scalar matrix-vector multiply: `out[i] = dot(w[i*cols..], x)`.
///
/// `w` is a `rows x cols` row-major matrix, `x` is a `cols`-vector, and
/// `out` is a pre-allocated `rows`-vector. Only `out.len()` rows are
/// computed; `_rows` exists for signature parity with the AVX2 kernel.
///
/// # Panics
///
/// Panics if `w` is shorter than `out.len() * cols` or (for non-empty
/// `out`) `x` is shorter than `cols`.
#[inline]
fn mat_vec_scalar(w: &[f64], x: &[f64], _rows: usize, cols: usize, out: &mut [f64]) {
    for (row, out_i) in out.iter_mut().enumerate() {
        let start = row * cols;
        // Slicing the row and zipping with `x[..cols]` hoists the bounds
        // checks out of the inner loop so it can auto-vectorize.
        let row_w = &w[start..start + cols];
        *out_i = row_w.iter().zip(&x[..cols]).map(|(&wi, &xj)| wi * xj).sum();
    }
}
68
69/// Scalar tanh: delegates to `crate::math::tanh`.
70#[inline]
71fn tanh_scalar(input: &[f64], output: &mut [f64]) {
72    for (i, &x) in input.iter().enumerate() {
73        output[i] = crate::math::tanh(x);
74    }
75}
76
77/// Scalar exp: delegates to `crate::math::exp`.
78#[inline]
79fn exp_scalar(input: &[f64], output: &mut [f64]) {
80    for (i, &x) in input.iter().enumerate() {
81        output[i] = crate::math::exp(x);
82    }
83}
84
85/// Scalar sigmoid: delegates to `crate::math::sigmoid`.
86#[inline]
87fn sigmoid_scalar(input: &[f64], output: &mut [f64]) {
88    for (i, &x) in input.iter().enumerate() {
89        output[i] = crate::math::sigmoid(x);
90    }
91}
92
93/// Scalar SiLU: `output[i] = input[i] * sigmoid(input[i])`.
94#[inline]
95fn silu_scalar(input: &[f64], output: &mut [f64]) {
96    for (i, &x) in input.iter().enumerate() {
97        output[i] = x * crate::math::sigmoid(x);
98    }
99}
100
101// ---------------------------------------------------------------------------
102// AVX2 implementations (x86_64 + std only)
103// ---------------------------------------------------------------------------
104
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
mod avx2 {
    /// AVX2-accelerated dot product: processes 4 f64 values per iteration.
    ///
    /// Accumulates 4-wide partial products, horizontally reduces the
    /// register, then finishes the `len % 4` tail with scalar multiplies.
    /// The SIMD accumulation order differs from `dot_scalar`, so results
    /// may differ by floating-point rounding (not bitwise identical).
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime (checked via
    /// `is_x86_feature_detected!("avx2")`).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn dot_avx2(a: &[f64], b: &[f64]) -> f64 {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        // Process min(a.len(), b.len()) elements: full 4-lane chunks
        // first, then a scalar remainder of 0..=3 elements.
        let n = a.len().min(b.len());
        let chunks = n / 4;
        let remainder = n % 4;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        // SAFETY: AVX2 availability verified by caller. All pointer arithmetic
        // stays within slice bounds (chunks * 4 <= n).
        unsafe {
            let mut acc = _mm256_setzero_pd();

            for i in 0..chunks {
                let offset = i * 4;
                // Unaligned loads: caller slices need not be 32-byte aligned.
                let va = _mm256_loadu_pd(a_ptr.add(offset));
                let vb = _mm256_loadu_pd(b_ptr.add(offset));
                acc = _mm256_add_pd(acc, _mm256_mul_pd(va, vb));
            }

            // Horizontal sum of 4 f64 lanes: [a0, a1, a2, a3]
            let hi128 = _mm256_extractf128_pd(acc, 1); // [a2, a3]
            let lo128 = _mm256_castpd256_pd128(acc); // [a0, a1]
            let pair = _mm_add_pd(lo128, hi128); // [a0+a2, a1+a3]
            let high64 = _mm_unpackhi_pd(pair, pair); // [a1+a3, a1+a3]
            let total = _mm_add_sd(pair, high64); // low lane = a0+a1+a2+a3
            let mut scalar_sum = _mm_cvtsd_f64(total);

            // Handle remainder with scalar tail.
            let base = chunks * 4;
            for i in 0..remainder {
                scalar_sum += *a_ptr.add(base + i) * *b_ptr.add(base + i);
            }

            scalar_sum
        }
    }

    /// AVX2-accelerated matrix-vector multiply.
    ///
    /// Each row is computed as a SIMD dot product of `w[row*cols..]` with `x`.
    /// Only `out.len()` rows are written; `_rows` exists for signature parity
    /// with `mat_vec_scalar`.
    ///
    /// # Safety
    ///
    /// Caller must ensure:
    /// - AVX2 is available at runtime
    /// - `w.len() >= rows * cols`, `x.len() >= cols`, `out.len() >= rows`
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn mat_vec_avx2(
        w: &[f64],
        x: &[f64],
        _rows: usize,
        cols: usize,
        out: &mut [f64],
    ) {
        for (row, out_i) in out.iter_mut().enumerate() {
            let row_start = row * cols;
            // SAFETY: caller ensures w has at least rows*cols elements.
            // dot_avx2 uses min(a.len(), b.len()) so slicing is safe.
            unsafe {
                *out_i = dot_avx2(&w[row_start..row_start + cols], &x[..cols]);
            }
        }
    }

    /// AVX2-accelerated tanh using a Padé \[2,2\] rational approximation.
    ///
    /// `tanh(x) ≈ x * (15 + x²) / (15 + 6*x²)` (matches the Taylor series
    /// through the x³ term), clamped to \[-1, 1\]; for x > 4.97
    /// (resp. x < -4.97) the output is forced to exactly +1.0 (resp. -1.0).
    ///
    /// Processes 4 f64 values per iteration with a scalar tail (via
    /// `crate::math::tanh`) for non-multiple-of-4 lengths.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn tanh_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        // SAFETY: AVX2 availability verified by caller. All pointer arithmetic
        // stays within slice bounds.
        unsafe {
            let c15 = _mm256_set1_pd(15.0);
            let c6 = _mm256_set1_pd(6.0);
            let pos_sat = _mm256_set1_pd(4.97);
            let neg_sat = _mm256_set1_pd(-4.97);
            let one = _mm256_set1_pd(1.0);
            let neg_one = _mm256_set1_pd(-1.0);

            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let x2 = _mm256_mul_pd(x, x);

                // Padé [2,2]: x * (15 + x²) / (15 + 6*x²)
                let numer = _mm256_mul_pd(x, _mm256_add_pd(c15, x2));
                let denom = _mm256_add_pd(c15, _mm256_mul_pd(c6, x2));
                let approx = _mm256_div_pd(numer, denom);

                // Clamp to [-1, 1]: the rational form exceeds ±1 for
                // moderately large |x| (it grows like x/6).
                let clamped = _mm256_min_pd(one, _mm256_max_pd(neg_one, approx));

                // Use saturation mask: if |x| > 4.97, output sign(x)
                let sat_pos = _mm256_cmp_pd(x, pos_sat, _CMP_GT_OQ);
                let sat_neg = _mm256_cmp_pd(x, neg_sat, _CMP_LT_OQ);
                let result = _mm256_blendv_pd(clamped, one, sat_pos);
                let result = _mm256_blendv_pd(result, neg_one, sat_neg);

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        // Scalar tail for remainder elements.
        for i in (chunks * 4)..n {
            output[i] = crate::math::tanh(input[i]);
        }
    }

    /// AVX2-accelerated exp using range reduction + degree-5 polynomial + 2^n scaling.
    ///
    /// Algorithm: clamp to [-708, 708], split x = n*ln2 + r where
    /// n = round(x/ln2) (half-up via floor(x*log2e + 0.5)), compute exp(r)
    /// via Horner polynomial, then scale by 2^n via IEEE 754 exponent
    /// manipulation. The clamp keeps |n| <= ~1022, so n fits the i32
    /// conversion and the biased exponent stays in the normal range.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn exp_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            let bias = _mm256_set1_epi64x(1023);

            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));

                // Clamp to prevent overflow/underflow
                let x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, x));

                // Range reduction: n = round(x / ln2), r = x - n*ln2
                let x_scaled = _mm256_mul_pd(x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                let r = _mm256_sub_pd(x, _mm256_mul_pd(n_f, ln2));

                // Horner polynomial: 1 + r*(1 + r*(0.5 + r*(1/6 + r*(1/24 + r/120))))
                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));

                // Scale by 2^n via IEEE 754 exponent manipulation:
                // build the double with biased exponent (n + 1023) << 52.
                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let result = _mm256_mul_pd(p, pow2n);

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        // Scalar tail
        for i in (chunks * 4)..n {
            output[i] = crate::math::exp(input[i]);
        }
    }

    /// AVX2-accelerated sigmoid: `1 / (1 + exp(-x))`.
    ///
    /// Inlines the same range-reduction exp kernel as `exp_avx2`, applied
    /// to `-x`, then evaluates the sigmoid formula — this avoids staging
    /// `exp(-x)` through a temporary buffer.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn sigmoid_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            let bias = _mm256_set1_epi64x(1023);
            let neg_one = _mm256_set1_pd(-1.0);

            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));

                // Negate x for exp(-x), then clamp
                let neg_x = _mm256_mul_pd(x, neg_one);
                let neg_x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, neg_x));

                // Range reduction (same scheme as exp_avx2)
                let x_scaled = _mm256_mul_pd(neg_x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                let r = _mm256_sub_pd(neg_x, _mm256_mul_pd(n_f, ln2));

                // Horner polynomial for exp(r)
                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));

                // Scale by 2^n
                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let exp_neg_x = _mm256_mul_pd(p, pow2n);

                // sigmoid = 1 / (1 + exp(-x))
                let result = _mm256_div_pd(one, _mm256_add_pd(one, exp_neg_x));

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        // Scalar tail
        for i in (chunks * 4)..n {
            output[i] = crate::math::sigmoid(input[i]);
        }
    }

    /// AVX2-accelerated SiLU: `x * sigmoid(x)`.
    ///
    /// Computes sigmoid via `sigmoid_avx2` into `output`, then multiplies
    /// element-wise by `input` in a second vectorized pass.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available at runtime.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn silu_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        // First compute sigmoid into output
        unsafe {
            sigmoid_avx2(input, output);
        }

        // Then multiply by input: output[i] = input[i] * sigmoid(input[i])
        let n = input.len();
        let chunks = n / 4;
        unsafe {
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let sig = _mm256_loadu_pd(output.as_ptr().add(off));
                _mm256_storeu_pd(output.as_mut_ptr().add(off), _mm256_mul_pd(x, sig));
            }
        }
        // Scalar tail: output already has sigmoid from scalar tail in sigmoid_avx2
        for i in (chunks * 4)..n {
            output[i] *= input[i];
        }
    }
}
408
409// ---------------------------------------------------------------------------
410// Public safe dispatch functions
411// ---------------------------------------------------------------------------
412
413/// SIMD-accelerated dot product with runtime feature detection.
414///
415/// Uses AVX2 on x86_64 (with `std` feature) when available, falls back to
416/// scalar otherwise.
417///
418/// Returns the dot product of `a` and `b`, processing up to the shorter
419/// slice's length.
420///
421/// # Examples
422///
423/// ```
424/// use irithyll_core::simd::simd_dot;
425///
426/// let a = [1.0, 2.0, 3.0];
427/// let b = [4.0, 5.0, 6.0];
428/// assert!((simd_dot(&a, &b) - 32.0).abs() < 1e-12);
429/// ```
430pub fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
431    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
432    {
433        if is_x86_feature_detected!("avx2") {
434            // SAFETY: we just checked for AVX2 support.
435            return unsafe { avx2::dot_avx2(a, b) };
436        }
437    }
438    dot_scalar(a, b)
439}
440
441/// SIMD-accelerated matrix-vector multiply with runtime feature detection.
442///
443/// Computes `out[i] = sum_j w[i*cols + j] * x[j]` for each row.
444/// Uses AVX2 on x86_64 (with `std` feature) when available, falls back to
445/// scalar otherwise.
446///
447/// # Panics
448///
449/// Panics if `w.len() < rows * cols`, `out.len() < rows`, or `x.len() < cols`.
450///
451/// # Examples
452///
453/// ```
454/// use irithyll_core::simd::simd_mat_vec;
455///
456/// // 2x3 matrix times 3-vector
457/// let w = [1.0, 2.0, 3.0,  4.0, 5.0, 6.0];
458/// let x = [1.0, 1.0, 1.0];
459/// let mut out = [0.0; 2];
460/// simd_mat_vec(&w, &x, 2, 3, &mut out);
461/// assert!((out[0] - 6.0).abs() < 1e-12);   // 1+2+3
462/// assert!((out[1] - 15.0).abs() < 1e-12);  // 4+5+6
463/// ```
464pub fn simd_mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
465    assert!(
466        w.len() >= rows * cols,
467        "simd_mat_vec: w.len()={} < rows*cols={}",
468        w.len(),
469        rows * cols
470    );
471    assert!(
472        out.len() >= rows,
473        "simd_mat_vec: out.len()={} < rows={}",
474        out.len(),
475        rows
476    );
477    assert!(
478        x.len() >= cols,
479        "simd_mat_vec: x.len()={} < cols={}",
480        x.len(),
481        cols
482    );
483
484    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
485    {
486        if is_x86_feature_detected!("avx2") {
487            // SAFETY: bounds checked above, AVX2 detected.
488            unsafe {
489                avx2::mat_vec_avx2(w, x, rows, cols, out);
490            }
491            return;
492        }
493    }
494    mat_vec_scalar(w, x, rows, cols, out);
495}
496
497/// SIMD-accelerated element-wise tanh with runtime feature detection.
498///
499/// Uses an AVX2 Padé \[2,2\] rational approximation on x86_64 (with `std`
500/// feature) when available, falls back to `crate::math::tanh` otherwise.
501///
502/// # Panics
503///
504/// Panics if `output.len() < input.len()`.
505///
506/// # Examples
507///
508/// ```
509/// use irithyll_core::simd::simd_tanh;
510///
511/// let input = [0.0, 1.0, -1.0];
512/// let mut output = [0.0; 3];
513/// simd_tanh(&input, &mut output);
514/// assert!(output[0].abs() < 1e-10);
515/// assert!((output[1] - 0.7616).abs() < 0.01);
516/// ```
517pub fn simd_tanh(input: &[f64], output: &mut [f64]) {
518    assert!(
519        output.len() >= input.len(),
520        "simd_tanh: output.len()={} < input.len()={}",
521        output.len(),
522        input.len()
523    );
524    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
525    {
526        if is_x86_feature_detected!("avx2") {
527            // SAFETY: bounds checked above, AVX2 detected.
528            unsafe {
529                avx2::tanh_avx2(input, output);
530            }
531            return;
532        }
533    }
534    tanh_scalar(input, output);
535}
536
537/// SIMD-accelerated element-wise exp with runtime feature detection.
538///
539/// Uses AVX2 range-reduction + degree-5 polynomial on x86_64 (with `std`
540/// feature) when available, falls back to `crate::math::exp` otherwise.
541///
542/// # Panics
543///
544/// Panics if `output.len() < input.len()`.
545///
546/// # Examples
547///
548/// ```
549/// use irithyll_core::simd::simd_exp;
550///
551/// let input = [0.0, 1.0];
552/// let mut output = [0.0; 2];
553/// simd_exp(&input, &mut output);
554/// assert!((output[0] - 1.0).abs() < 1e-10);
555/// assert!((output[1] - core::f64::consts::E).abs() < 1e-10);
556/// ```
557pub fn simd_exp(input: &[f64], output: &mut [f64]) {
558    assert!(
559        output.len() >= input.len(),
560        "simd_exp: output.len()={} < input.len()={}",
561        output.len(),
562        input.len()
563    );
564    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
565    {
566        if is_x86_feature_detected!("avx2") {
567            // SAFETY: bounds checked above, AVX2 detected.
568            unsafe {
569                avx2::exp_avx2(input, output);
570            }
571            return;
572        }
573    }
574    exp_scalar(input, output);
575}
576
577/// SIMD-accelerated element-wise sigmoid with runtime feature detection.
578///
579/// Uses AVX2 vectorized exp on x86_64 (with `std` feature) when available,
580/// falls back to `crate::math::sigmoid` otherwise.
581///
582/// # Panics
583///
584/// Panics if `output.len() < input.len()`.
585///
586/// # Examples
587///
588/// ```
589/// use irithyll_core::simd::simd_sigmoid;
590///
591/// let input = [0.0];
592/// let mut output = [0.0; 1];
593/// simd_sigmoid(&input, &mut output);
594/// assert!((output[0] - 0.5).abs() < 1e-10);
595/// ```
596pub fn simd_sigmoid(input: &[f64], output: &mut [f64]) {
597    assert!(
598        output.len() >= input.len(),
599        "simd_sigmoid: output.len()={} < input.len()={}",
600        output.len(),
601        input.len()
602    );
603    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
604    {
605        if is_x86_feature_detected!("avx2") {
606            // SAFETY: bounds checked above, AVX2 detected.
607            unsafe {
608                avx2::sigmoid_avx2(input, output);
609            }
610            return;
611        }
612    }
613    sigmoid_scalar(input, output);
614}
615
616/// SIMD-accelerated element-wise SiLU (Sigmoid Linear Unit) with runtime
617/// feature detection.
618///
619/// Computes `output[i] = input[i] * sigmoid(input[i])` for each element.
620/// Uses AVX2 vectorized sigmoid on x86_64 (with `std` feature) when
621/// available, falls back to scalar otherwise.
622///
623/// # Panics
624///
625/// Panics if `output.len() < input.len()`.
626///
627/// # Examples
628///
629/// ```
630/// use irithyll_core::simd::simd_silu;
631///
632/// let input = [0.0];
633/// let mut output = [0.0; 1];
634/// simd_silu(&input, &mut output);
635/// assert!(output[0].abs() < 1e-10); // silu(0) = 0 * 0.5 = 0
636/// ```
637pub fn simd_silu(input: &[f64], output: &mut [f64]) {
638    assert!(
639        output.len() >= input.len(),
640        "simd_silu: output.len()={} < input.len()={}",
641        output.len(),
642        input.len()
643    );
644    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
645    {
646        if is_x86_feature_detected!("avx2") {
647            // SAFETY: bounds checked above, AVX2 detected.
648            unsafe {
649                avx2::silu_avx2(input, output);
650            }
651            return;
652        }
653    }
654    silu_scalar(input, output);
655}
656
657// ===========================================================================
658// Tests
659// ===========================================================================
660
661#[cfg(test)]
662mod tests {
663    use super::*;
664    use alloc::vec;
665    use alloc::vec::Vec;
666
667    // Simple deterministic PRNG for test data generation.
668    struct TestRng(u64);
669
670    impl TestRng {
671        fn new(seed: u64) -> Self {
672            Self(seed)
673        }
674
675        fn next_u64(&mut self) -> u64 {
676            let mut x = self.0;
677            x ^= x << 13;
678            x ^= x >> 7;
679            x ^= x << 17;
680            self.0 = x;
681            x
682        }
683
684        fn next_f64(&mut self) -> f64 {
685            // Map to [-1, 1) range for interesting test values
686            (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) * 2.0 - 1.0
687        }
688
689        fn fill_vec(&mut self, n: usize) -> Vec<f64> {
690            (0..n).map(|_| self.next_f64()).collect()
691        }
692    }
693
694    // -------------------------------------------------------------------
695    // Dot product tests
696    // -------------------------------------------------------------------
697
698    #[test]
699    fn dot_empty_returns_zero() {
700        let a: [f64; 0] = [];
701        let b: [f64; 0] = [];
702        assert_eq!(simd_dot(&a, &b), 0.0, "dot of empty slices should be 0");
703    }
704
705    #[test]
706    fn dot_single_element() {
707        let a = [3.0];
708        let b = [4.0];
709        assert!(
710            (simd_dot(&a, &b) - 12.0).abs() < 1e-12,
711            "dot([3], [4]) should be 12, got {}",
712            simd_dot(&a, &b)
713        );
714    }
715
716    #[test]
717    fn dot_known_result() {
718        let a = [1.0, 2.0, 3.0];
719        let b = [4.0, 5.0, 6.0];
720        let result = simd_dot(&a, &b);
721        assert!(
722            (result - 32.0).abs() < 1e-12,
723            "dot([1,2,3], [4,5,6]) should be 32, got {}",
724            result
725        );
726    }
727
728    #[test]
729    fn dot_large_matches_scalar() {
730        let mut rng = TestRng::new(42);
731        let a = rng.fill_vec(1000);
732        let b = rng.fill_vec(1000);
733
734        let simd_result = simd_dot(&a, &b);
735        let scalar_result = dot_scalar(&a, &b);
736
737        assert!(
738            (simd_result - scalar_result).abs() < 1e-9,
739            "1000-element dot: SIMD={} vs scalar={}, diff={}",
740            simd_result,
741            scalar_result,
742            (simd_result - scalar_result).abs()
743        );
744    }
745
746    #[test]
747    fn dot_mismatched_lengths() {
748        // Should use the shorter length
749        let a = [1.0, 2.0, 3.0, 999.0];
750        let b = [4.0, 5.0, 6.0];
751        let result = simd_dot(&a, &b);
752        assert!(
753            (result - 32.0).abs() < 1e-12,
754            "mismatched lengths should use min, expected 32, got {}",
755            result
756        );
757    }
758
759    #[test]
760    fn dot_non_aligned_length() {
761        // 7 elements: 1 full AVX2 chunk (4) + 3 remainder
762        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
763        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
764        let result = simd_dot(&a, &b);
765        assert!(
766            (result - 28.0).abs() < 1e-12,
767            "dot of [1..7] with [1..1] should be 28, got {}",
768            result
769        );
770    }
771
772    #[test]
773    fn dot_negative_values() {
774        let a = [-1.0, -2.0, -3.0, -4.0];
775        let b = [4.0, 3.0, 2.0, 1.0];
776        // -4 + -6 + -6 + -4 = -20
777        let result = simd_dot(&a, &b);
778        assert!(
779            (result - (-20.0)).abs() < 1e-12,
780            "expected -20, got {}",
781            result
782        );
783    }
784
785    #[test]
786    fn dot_orthogonal_vectors() {
787        let a = [1.0, 0.0, 0.0, 0.0];
788        let b = [0.0, 1.0, 0.0, 0.0];
789        let result = simd_dot(&a, &b);
790        assert!(
791            result.abs() < 1e-12,
792            "orthogonal vectors should have dot=0, got {}",
793            result
794        );
795    }
796
797    // -------------------------------------------------------------------
798    // Matrix-vector multiply tests
799    // -------------------------------------------------------------------
800
801    #[test]
802    fn mat_vec_identity_like() {
803        // 3x3 identity matrix times [1, 2, 3] = [1, 2, 3]
804        let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
805        let x = [1.0, 2.0, 3.0];
806        let mut out = [0.0; 3];
807        simd_mat_vec(&w, &x, 3, 3, &mut out);
808        assert!(
809            (out[0] - 1.0).abs() < 1e-12,
810            "identity row 0: expected 1, got {}",
811            out[0]
812        );
813        assert!(
814            (out[1] - 2.0).abs() < 1e-12,
815            "identity row 1: expected 2, got {}",
816            out[1]
817        );
818        assert!(
819            (out[2] - 3.0).abs() < 1e-12,
820            "identity row 2: expected 3, got {}",
821            out[2]
822        );
823    }
824
825    #[test]
826    fn mat_vec_known_result() {
827        // 2x3 matrix:
828        // [1 2 3]   [1]   [1+4+9]   [14]
829        // [4 5 6] * [2] = [4+10+18] = [32]
830        //           [3]
831        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
832        let x = [1.0, 2.0, 3.0];
833        let mut out = [0.0; 2];
834        simd_mat_vec(&w, &x, 2, 3, &mut out);
835        assert!(
836            (out[0] - 14.0).abs() < 1e-12,
837            "row 0: expected 14, got {}",
838            out[0]
839        );
840        assert!(
841            (out[1] - 32.0).abs() < 1e-12,
842            "row 1: expected 32, got {}",
843            out[1]
844        );
845    }
846
847    #[test]
848    fn mat_vec_large_matches_scalar() {
849        let mut rng = TestRng::new(7777);
850        let rows = 100;
851        let cols = 100;
852        let w = rng.fill_vec(rows * cols);
853        let x = rng.fill_vec(cols);
854        let mut out_simd = vec![0.0; rows];
855        let mut out_scalar = vec![0.0; rows];
856
857        simd_mat_vec(&w, &x, rows, cols, &mut out_simd);
858        mat_vec_scalar(&w, &x, rows, cols, &mut out_scalar);
859
860        for i in 0..rows {
861            assert!(
862                (out_simd[i] - out_scalar[i]).abs() < 1e-9,
863                "row {}: SIMD={} vs scalar={}, diff={}",
864                i,
865                out_simd[i],
866                out_scalar[i],
867                (out_simd[i] - out_scalar[i]).abs()
868            );
869        }
870    }
871
872    #[test]
873    fn mat_vec_single_row() {
874        // 1xN is just a dot product
875        let w = [1.0, 2.0, 3.0, 4.0, 5.0];
876        let x = [2.0, 2.0, 2.0, 2.0, 2.0];
877        let mut out = [0.0; 1];
878        simd_mat_vec(&w, &x, 1, 5, &mut out);
879        // 2+4+6+8+10 = 30
880        assert!(
881            (out[0] - 30.0).abs() < 1e-12,
882            "single-row mat_vec should be dot product, expected 30, got {}",
883            out[0]
884        );
885    }
886
887    #[test]
888    fn mat_vec_single_element() {
889        let w = [7.0];
890        let x = [3.0];
891        let mut out = [0.0; 1];
892        simd_mat_vec(&w, &x, 1, 1, &mut out);
893        assert!(
894            (out[0] - 21.0).abs() < 1e-12,
895            "1x1 mat_vec: 7*3=21, got {}",
896            out[0]
897        );
898    }
899
900    // -------------------------------------------------------------------
901    // Panic tests
902    // -------------------------------------------------------------------
903
904    #[test]
905    #[should_panic(expected = "simd_mat_vec: w.len()")]
906    fn mat_vec_panics_w_too_short() {
907        let w = [1.0, 2.0]; // need 2*3=6
908        let x = [1.0, 2.0, 3.0];
909        let mut out = [0.0; 2];
910        simd_mat_vec(&w, &x, 2, 3, &mut out);
911    }
912
913    #[test]
914    #[should_panic(expected = "simd_mat_vec: out.len()")]
915    fn mat_vec_panics_out_too_short() {
916        let w = [1.0; 6];
917        let x = [1.0; 3];
918        let mut out = [0.0; 1]; // need 2
919        simd_mat_vec(&w, &x, 2, 3, &mut out);
920    }
921
922    #[test]
923    #[should_panic(expected = "simd_mat_vec: x.len()")]
924    fn mat_vec_panics_x_too_short() {
925        let w = [1.0; 6];
926        let x = [1.0; 2]; // need 3
927        let mut out = [0.0; 2];
928        simd_mat_vec(&w, &x, 2, 3, &mut out);
929    }
930
931    // -------------------------------------------------------------------
932    // Platform-specific test
933    // -------------------------------------------------------------------
934
935    #[cfg(all(target_arch = "x86_64", feature = "std"))]
936    #[test]
937    fn simd_available_on_x86() {
938        // On modern x86_64, AVX2 should be available.
939        // This test verifies the runtime detection path doesn't panic.
940        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
941        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
942        let result = simd_dot(&a, &b);
943        // 8+14+18+20+20+18+14+8 = 120
944        assert!(
945            (result - 120.0).abs() < 1e-12,
946            "8-element dot product should be 120, got {}",
947            result
948        );
949
950        // AVX2 detection is platform-specific; skip in no_std test context.
951    }
952
953    // -------------------------------------------------------------------
954    // Activation function tests
955    // -------------------------------------------------------------------
956
957    #[test]
958    fn tanh_known_values() {
959        let input = [0.0, 1.0, -1.0, 5.0, -5.0, 0.5];
960        let mut output = [0.0; 6];
961        simd_tanh(&input, &mut output);
962        let expected = [0.0, 0.7616, -0.7616, 0.9999, -0.9999, 0.4621];
963        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
964            assert!(
965                (got - exp).abs() < 0.01,
966                "tanh[{i}]: expected ~{exp}, got {got}"
967            );
968        }
969    }
970
971    #[test]
972    fn tanh_matches_scalar() {
973        let mut rng = TestRng::new(42);
974        let input = rng.fill_vec(100);
975        let mut simd_out = vec![0.0; 100];
976        let mut scalar_out = vec![0.0; 100];
977        simd_tanh(&input, &mut simd_out);
978        for (i, &x) in input.iter().enumerate() {
979            scalar_out[i] = crate::math::tanh(x);
980        }
981        for i in 0..100 {
982            assert!(
983                (simd_out[i] - scalar_out[i]).abs() < 0.01,
984                "tanh[{i}]: SIMD={} vs scalar={}",
985                simd_out[i],
986                scalar_out[i]
987            );
988        }
989    }
990
991    #[test]
992    fn exp_known_values() {
993        let input = [0.0, 1.0, -1.0, 2.0, -2.0];
994        let mut output = [0.0; 5];
995        simd_exp(&input, &mut output);
996        let expected = [
997            1.0,
998            core::f64::consts::E,
999            1.0 / core::f64::consts::E,
1000            core::f64::consts::E * core::f64::consts::E,
1001            1.0 / (core::f64::consts::E * core::f64::consts::E),
1002        ];
1003        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1004            let rel = (got - exp).abs() / exp.abs().max(1e-15);
1005            assert!(
1006                rel < 1e-5,
1007                "exp[{i}]: expected {exp}, got {got}, rel_err={rel}"
1008            );
1009        }
1010    }
1011
1012    #[test]
1013    fn exp_matches_scalar() {
1014        let mut rng = TestRng::new(99);
1015        // Generate values in range [-10, 10] for good coverage
1016        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0).collect();
1017        let mut simd_out = vec![0.0; 100];
1018        let mut scalar_out = vec![0.0; 100];
1019        simd_exp(&input, &mut simd_out);
1020        for (i, &x) in input.iter().enumerate() {
1021            scalar_out[i] = crate::math::exp(x);
1022        }
1023        for i in 0..100 {
1024            let rel = (simd_out[i] - scalar_out[i]).abs() / scalar_out[i].abs().max(1e-15);
1025            assert!(
1026                rel < 1e-5,
1027                "exp[{i}] (x={}): SIMD={} vs scalar={}, rel_err={}",
1028                input[i],
1029                simd_out[i],
1030                scalar_out[i],
1031                rel
1032            );
1033        }
1034    }
1035
1036    #[test]
1037    fn exp_extreme_values() {
1038        // Test clamping behavior at boundaries
1039        let input = [700.0, -700.0, 0.0, 100.0, -100.0];
1040        let mut output = [0.0; 5];
1041        simd_exp(&input, &mut output);
1042        // exp(700) is huge but finite
1043        assert!(output[0].is_finite(), "exp(700) should be finite");
1044        assert!(output[0] > 0.0, "exp(700) should be positive");
1045        // exp(-700) is tiny but positive
1046        assert!(output[1] > 0.0, "exp(-700) should be positive");
1047        assert!(output[1].is_finite(), "exp(-700) should be finite");
1048        // exp(0) = 1
1049        assert!((output[2] - 1.0).abs() < 1e-12, "exp(0) should be 1.0");
1050    }
1051
1052    #[test]
1053    fn sigmoid_known_values() {
1054        let input = [0.0, 10.0, -10.0, 1.0];
1055        let mut output = [0.0; 4];
1056        simd_sigmoid(&input, &mut output);
1057        assert!(
1058            (output[0] - 0.5).abs() < 0.01,
1059            "sigmoid(0) should be ~0.5, got {}",
1060            output[0]
1061        );
1062        assert!(
1063            output[1] > 0.99,
1064            "sigmoid(10) should be ~1.0, got {}",
1065            output[1]
1066        );
1067        assert!(
1068            output[2] < 0.01,
1069            "sigmoid(-10) should be ~0.0, got {}",
1070            output[2]
1071        );
1072    }
1073
1074    #[test]
1075    fn sigmoid_matches_scalar() {
1076        let mut rng = TestRng::new(123);
1077        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 20.0 - 10.0).collect();
1078        let mut simd_out = vec![0.0; 100];
1079        let mut scalar_out = vec![0.0; 100];
1080        simd_sigmoid(&input, &mut simd_out);
1081        for (i, &x) in input.iter().enumerate() {
1082            scalar_out[i] = crate::math::sigmoid(x);
1083        }
1084        for i in 0..100 {
1085            assert!(
1086                (simd_out[i] - scalar_out[i]).abs() < 1e-6,
1087                "sigmoid[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
1088                input[i],
1089                simd_out[i],
1090                scalar_out[i],
1091                (simd_out[i] - scalar_out[i]).abs()
1092            );
1093        }
1094    }
1095
1096    #[test]
1097    fn silu_known_values() {
1098        let input = [0.0, 1.0, -1.0, 3.0];
1099        let mut output = [0.0; 4];
1100        simd_silu(&input, &mut output);
1101        // silu(0) = 0 * 0.5 = 0
1102        assert!(
1103            output[0].abs() < 0.01,
1104            "silu(0) should be ~0, got {}",
1105            output[0]
1106        );
1107        // silu(1) = 1 * sigmoid(1) ~ 0.731
1108        assert!(
1109            (output[1] - 0.731).abs() < 0.01,
1110            "silu(1) should be ~0.731, got {}",
1111            output[1]
1112        );
1113    }
1114
1115    #[test]
1116    fn silu_matches_scalar() {
1117        let mut rng = TestRng::new(456);
1118        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0 - 5.0).collect();
1119        let mut simd_out = vec![0.0; 100];
1120        simd_silu(&input, &mut simd_out);
1121        for (i, &x) in input.iter().enumerate() {
1122            let expected = x * crate::math::sigmoid(x);
1123            assert!(
1124                (simd_out[i] - expected).abs() < 1e-6,
1125                "silu[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
1126                x,
1127                simd_out[i],
1128                expected,
1129                (simd_out[i] - expected).abs()
1130            );
1131        }
1132    }
1133
1134    #[test]
1135    fn activations_handle_empty() {
1136        let input: [f64; 0] = [];
1137        let mut output: [f64; 0] = [];
1138        simd_tanh(&input, &mut output);
1139        simd_exp(&input, &mut output);
1140        simd_sigmoid(&input, &mut output);
1141        simd_silu(&input, &mut output);
1142    }
1143}