Skip to main content

trueno/backends/sse2/ops/
arithmetic.rs

1//! SSE2 arithmetic operations (add, sub, mul, div).
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
/// SSE2 element-wise vector addition: `result[i] = a[i] + b[i]` for every `i` in `0..a.len()`.
///
/// Four `f32` lanes are processed per iteration with unaligned loads/stores; the final
/// `a.len() % 4` elements are handled by a scalar loop. Elements of `result` at indices
/// `>= a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Without this check the SIMD loop would
/// read or write past the end of the shorter slice (undefined behavior).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by
/// `#[target_feature(enable = "sse2")]`.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    // The raw-pointer SIMD loop below relies on these bounds.
    assert!(
        b.len() >= len && result.len() >= len,
        "sse2 add: b and result must be at least as long as a"
    );
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and `len <= b.len()` and `len <= result.len()` (asserted
        // above), so each 4-lane access stays in bounds; `loadu`/`storeu` need no alignment.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_add_ps(va, vb));
        }
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    for ((r, &x), &y) in result[i..len].iter_mut().zip(&a[i..len]).zip(&b[i..len]) {
        *r = x + y;
    }
}
25
/// SSE2 element-wise vector subtraction: `result[i] = a[i] - b[i]` for every `i` in `0..a.len()`.
///
/// Four `f32` lanes are processed per iteration with unaligned loads/stores; the final
/// `a.len() % 4` elements are handled by a scalar loop. Elements of `result` at indices
/// `>= a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Without this check the SIMD loop would
/// read or write past the end of the shorter slice (undefined behavior).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by
/// `#[target_feature(enable = "sse2")]`.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    // The raw-pointer SIMD loop below relies on these bounds.
    assert!(
        b.len() >= len && result.len() >= len,
        "sse2 sub: b and result must be at least as long as a"
    );
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and `len <= b.len()` and `len <= result.len()` (asserted
        // above), so each 4-lane access stays in bounds; `loadu`/`storeu` need no alignment.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_sub_ps(va, vb));
        }
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    for ((r, &x), &y) in result[i..len].iter_mut().zip(&a[i..len]).zip(&b[i..len]) {
        *r = x - y;
    }
}
45
/// SSE2 element-wise vector multiplication: `result[i] = a[i] * b[i]` for every `i` in
/// `0..a.len()`.
///
/// Four `f32` lanes are processed per iteration with unaligned loads/stores; the final
/// `a.len() % 4` elements are handled by a scalar loop. Elements of `result` at indices
/// `>= a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Without this check the SIMD loop would
/// read or write past the end of the shorter slice (undefined behavior).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by
/// `#[target_feature(enable = "sse2")]`.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    // The raw-pointer SIMD loop below relies on these bounds.
    assert!(
        b.len() >= len && result.len() >= len,
        "sse2 mul: b and result must be at least as long as a"
    );
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and `len <= b.len()` and `len <= result.len()` (asserted
        // above), so each 4-lane access stays in bounds; `loadu`/`storeu` need no alignment.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_mul_ps(va, vb));
        }
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    for ((r, &x), &y) in result[i..len].iter_mut().zip(&a[i..len]).zip(&b[i..len]) {
        *r = x * y;
    }
}
65
/// SSE2 element-wise vector division: `result[i] = a[i] / b[i]` for every `i` in `0..a.len()`.
///
/// Four `f32` lanes are processed per iteration with unaligned loads/stores; the final
/// `a.len() % 4` elements are handled by a scalar loop. Elements of `result` at indices
/// `>= a.len()` are left untouched. Division follows IEEE-754 float semantics in both
/// paths: dividing by zero yields ±infinity or NaN rather than panicking.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Without this check the SIMD loop would
/// read or write past the end of the shorter slice (undefined behavior).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by
/// `#[target_feature(enable = "sse2")]`.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    // The raw-pointer SIMD loop below relies on these bounds.
    assert!(
        b.len() >= len && result.len() >= len,
        "sse2 div: b and result must be at least as long as a"
    );
    let mut i = 0;
    while i + 4 <= len {
        // SAFETY: `i + 4 <= len`, and `len <= b.len()` and `len <= result.len()` (asserted
        // above), so each 4-lane access stays in bounds; `loadu`/`storeu` need no alignment.
        unsafe {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_div_ps(va, vb));
        }
        i += 4;
    }
    // Scalar tail for the remaining `len % 4` elements.
    for ((r, &x), &y) in result[i..len].iter_mut().zip(&a[i..len]).zip(&b[i..len]) {
        *r = x / y;
    }
}