trueno/backends/sse2/ops/
reductions.rs1#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6use crate::backends::VectorBackend;
7
8#[inline]
10#[target_feature(enable = "sse2")]
11pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
13 unsafe {
14 let len = a.len();
15 let mut i = 0;
16 let mut acc = _mm_setzero_ps();
17
18 while i + 4 <= len {
19 let va = _mm_loadu_ps(a.as_ptr().add(i));
20 let vb = _mm_loadu_ps(b.as_ptr().add(i));
21 acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
22 i += 4;
23 }
24
25 let mut result = horizontal_sum_ps(acc);
26 result += a[i..].iter().zip(&b[i..]).map(|(x, y)| x * y).sum::<f32>();
27 result
28 }
29}
30
31#[inline]
33#[target_feature(enable = "sse2")]
34pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
36 unsafe {
37 let len = a.len();
38 let mut i = 0;
39 let mut acc = _mm_setzero_ps();
40
41 while i + 4 <= len {
42 acc = _mm_add_ps(acc, _mm_loadu_ps(a.as_ptr().add(i)));
43 i += 4;
44 }
45
46 let mut result = horizontal_sum_ps(acc);
47 result += a[i..].iter().sum::<f32>();
48 result
49 }
50}
51
52#[inline]
54#[target_feature(enable = "sse2")]
55pub(crate) unsafe fn max(a: &[f32]) -> f32 {
57 unsafe {
58 let len = a.len();
59 let mut i = 0;
60 let mut vmax = _mm_set1_ps(a[0]);
61
62 while i + 4 <= len {
63 vmax = _mm_max_ps(vmax, _mm_loadu_ps(a.as_ptr().add(i)));
64 i += 4;
65 }
66
67 let mut result = horizontal_max_ps(vmax);
68 for &val in &a[i..] {
69 if val > result {
70 result = val;
71 }
72 }
73 result
74 }
75}
76
77#[inline]
79#[target_feature(enable = "sse2")]
80pub(crate) unsafe fn min(a: &[f32]) -> f32 {
82 unsafe {
83 let len = a.len();
84 let mut i = 0;
85 let mut vmin = _mm_set1_ps(a[0]);
86
87 while i + 4 <= len {
88 vmin = _mm_min_ps(vmin, _mm_loadu_ps(a.as_ptr().add(i)));
89 i += 4;
90 }
91
92 let mut result = horizontal_min_ps(vmin);
93 for &val in &a[i..] {
94 if val < result {
95 result = val;
96 }
97 }
98 result
99 }
100}
101
/// Returns the index of the largest element of `a`.
///
/// The scan is a plain scalar loop (no SIMD intrinsics are used here). A
/// later element only wins when it is strictly greater than the current best,
/// so ties resolve to the earliest index.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by the
/// `target_feature` attribute on this function.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
    // (index, value) of the best candidate seen so far, seeded with element 0.
    let mut best = (0usize, a[0]);

    // Element 0 can never be strictly greater than itself, so start at 1.
    for (idx, &candidate) in a.iter().enumerate().skip(1) {
        if candidate > best.1 {
            best = (idx, candidate);
        }
    }
    best.0
}
119
/// Returns the index of the smallest element of `a`.
///
/// The scan is a plain scalar loop (no SIMD intrinsics are used here). A
/// later element only wins when it is strictly smaller than the current best,
/// so ties resolve to the earliest index.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2, as required by the
/// `target_feature` attribute on this function.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
    // (index, value) of the best candidate seen so far, seeded with element 0.
    let mut best = (0usize, a[0]);

    // Element 0 can never be strictly smaller than itself, so start at 1.
    for (idx, &candidate) in a.iter().enumerate().skip(1) {
        if candidate < best.1 {
            best = (idx, candidate);
        }
    }
    best.0
}
137
/// Sums `a` by delegating to the scalar backend's `sum_kahan`.
///
/// This function contains no SSE2 code of its own; it simply forwards the
/// slice to `ScalarBackend::sum_kahan` and returns that result.
/// NOTE(review): presumably this is compensated (Kahan) summation — the
/// implementation lives in the scalar backend and is not visible here.
///
/// # Safety
///
/// The caller must uphold whatever contract `ScalarBackend::sum_kahan`
/// requires; this wrapper adds no requirements of its own.
#[inline]
pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
    // Forward unchanged to the scalar implementation.
    unsafe { crate::backends::scalar::ScalarBackend::sum_kahan(a) }
}
144
/// Adds the four `f32` lanes of `v` together and returns the scalar total.
///
/// The reduction is pairwise: lanes 0 and 2 and lanes 1 and 3 are added
/// first, then those two partial sums are combined in lane 0.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_sum_ps(v: __m128) -> f32 {
    // Bring the upper two lanes down next to the lower two and add pairwise.
    let upper_half = _mm_movehl_ps(v, v);
    let pair_sums = _mm_add_ps(v, upper_half);
    // Move lane 1 into lane 0 so the two partial sums can be combined.
    let second_pair = _mm_shuffle_ps(pair_sums, pair_sums, 1);
    let total = _mm_add_ss(pair_sums, second_pair);
    _mm_cvtss_f32(total)
}
154
/// Reduces the four `f32` lanes of `v` to their maximum.
///
/// Lanes 0 and 2 and lanes 1 and 3 are compared first, then the two partial
/// maxima are compared in lane 0.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_max_ps(v: __m128) -> f32 {
    // Bring the upper two lanes down next to the lower two and compare pairwise.
    let upper_half = _mm_movehl_ps(v, v);
    let pair_max = _mm_max_ps(v, upper_half);
    // Move lane 1 into lane 0 so the two partial maxima can be compared.
    let second_pair = _mm_shuffle_ps(pair_max, pair_max, 1);
    let overall = _mm_max_ss(pair_max, second_pair);
    _mm_cvtss_f32(overall)
}
164
/// Reduces the four `f32` lanes of `v` to their minimum.
///
/// Lanes 0 and 2 and lanes 1 and 3 are compared first, then the two partial
/// minima are compared in lane 0.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_min_ps(v: __m128) -> f32 {
    // Bring the upper two lanes down next to the lower two and compare pairwise.
    let upper_half = _mm_movehl_ps(v, v);
    let pair_min = _mm_min_ps(v, upper_half);
    // Move lane 1 into lane 0 so the two partial minima can be compared.
    let second_pair = _mm_shuffle_ps(pair_min, pair_min, 1);
    let overall = _mm_min_ss(pair_min, second_pair);
    _mm_cvtss_f32(overall)
}