trueno/backends/avx2/ops/arithmetic.rs
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Minimum length at which the kernels switch to non-temporal (streaming)
/// stores; below this the output likely still fits in cache, so regular
/// stores are cheaper. The exact value is a tuning knob.
const NT_THRESHOLD: usize = 8192;

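/// Element-wise `result[i] = a[i] + b[i]` using AVX2.
///
/// Dispatches to the streaming-store kernel for long, 32-byte-aligned
/// outputs and to the cached kernel otherwise.
///
/// # Safety
///
/// AVX2 must be available, and `b` and `result` must each be at least
/// `a.len()` elements long.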
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let rp_aligned = (result.as_ptr() as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            add_nt(a, b, result);
        } else {
            add_cached(a, b, result);
        }
    }
}

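/// Cached-store path: unaligned AVX2 loads/stores, one 8-lane vector per
/// iteration, with a scalar tail for the remainder.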
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(result.as_mut_ptr().add(i), _mm256_add_ps(va, vb));
            i += 8;
        }
        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}

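/// Non-temporal path: streams results past the cache with
/// `_mm256_stream_ps`, so a large output does not evict hot data.
/// Assumes `result` is 32-byte aligned (checked by the caller in `add`).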
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        // Main loop: 4x unrolled (32 floats per iteration), prefetching the
        // inputs 64 elements (4 cache lines) ahead.
        while i + 32 <= len {
            _mm_prefetch(ap.add(i + 64).cast::<i8>(), _MM_HINT_T0);
            _mm_prefetch(bp.add(i + 64).cast::<i8>(), _MM_HINT_T0);

            let va0 = _mm256_loadu_ps(ap.add(i));
            let vb0 = _mm256_loadu_ps(bp.add(i));
            let va1 = _mm256_loadu_ps(ap.add(i + 8));
            let vb1 = _mm256_loadu_ps(bp.add(i + 8));
            let va2 = _mm256_loadu_ps(ap.add(i + 16));
            let vb2 = _mm256_loadu_ps(bp.add(i + 16));
            let va3 = _mm256_loadu_ps(ap.add(i + 24));
            let vb3 = _mm256_loadu_ps(bp.add(i + 24));

            _mm256_stream_ps(rp.add(i), _mm256_add_ps(va0, vb0));
            _mm256_stream_ps(rp.add(i + 8), _mm256_add_ps(va1, vb1));
            _mm256_stream_ps(rp.add(i + 16), _mm256_add_ps(va2, vb2));
            _mm256_stream_ps(rp.add(i + 24), _mm256_add_ps(va3, vb3));

            i += 32;
        }

        // Single-vector cleanup, still streamed.
        while i + 8 <= len {
            let va = _mm256_loadu_ps(ap.add(i));
            let vb = _mm256_loadu_ps(bp.add(i));
            _mm256_stream_ps(rp.add(i), _mm256_add_ps(va, vb));
            i += 8;
        }

        // Non-temporal stores are weakly ordered; fence so they are visible
        // before the scalar tail and before the caller reads `result`.
        _mm_sfence();

        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}

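/// Element-wise `result[i] = a[i] - b[i]` using AVX2, with the same
/// cached/non-temporal dispatch as `add` but written as a single body.
///
/// # Safety
///
/// AVX2 must be available, and `b` and `result` must each be at least
/// `a.len()` elements long.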
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        let rp_aligned = (rp as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            while i + 32 <= len {
                _mm_prefetch(ap.add(i + 64).cast::<i8>(), _MM_HINT_T0);
                _mm_prefetch(bp.add(i + 64).cast::<i8>(), _MM_HINT_T0);

                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                _mm256_stream_ps(
                    rp.add(i + 8),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 8)), _mm256_loadu_ps(bp.add(i + 8))),
                );
                _mm256_stream_ps(
                    rp.add(i + 16),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 16)), _mm256_loadu_ps(bp.add(i + 16))),
                );
                _mm256_stream_ps(
                    rp.add(i + 24),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 24)), _mm256_loadu_ps(bp.add(i + 24))),
                );
                i += 32;
            }
            while i + 8 <= len {
                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
            _mm_sfence();
        } else {
            while i + 8 <= len {
                _mm256_storeu_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
        }

        for j in i..len {
            result[j] = a[j] - b[j];
        }
    }
}

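/// Element-wise `result[i] = a[i] * b[i]` using AVX2; mirrors `sub`,
/// streaming to memory for long aligned outputs and using cached stores
/// otherwise.
///
/// # Safety
///
/// AVX2 must be available, and `b` and `result` must each be at least
/// `a.len()` elements long.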
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        let rp_aligned = (rp as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            while i + 32 <= len {
                _mm_prefetch(ap.add(i + 64).cast::<i8>(), _MM_HINT_T0);
                _mm_prefetch(bp.add(i + 64).cast::<i8>(), _MM_HINT_T0);

                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                _mm256_stream_ps(
                    rp.add(i + 8),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 8)), _mm256_loadu_ps(bp.add(i + 8))),
                );
                _mm256_stream_ps(
                    rp.add(i + 16),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 16)), _mm256_loadu_ps(bp.add(i + 16))),
                );
                _mm256_stream_ps(
                    rp.add(i + 24),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 24)), _mm256_loadu_ps(bp.add(i + 24))),
                );
                i += 32;
            }
            while i + 8 <= len {
                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
            _mm_sfence();
        } else {
            while i + 8 <= len {
                _mm256_storeu_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
        }

        for j in i..len {
            result[j] = a[j] * b[j];
        }
    }
}

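/// Element-wise `result[i] = a[i] / b[i]` using AVX2. Only the cached path
/// is provided; there is no non-temporal variant of this kernel.
///
/// # Safety
///
/// AVX2 must be available, and `b` and `result` must each be at least
/// `a.len()` elements long.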
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(result.as_mut_ptr().add(i), _mm256_div_ps(va, vb));
            i += 8;
        }
        for j in i..len {
            result[j] = a[j] / b[j];
        }
    }
}
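
// A minimal sketch of how these kernels could be exercised; the test name
// and the runtime-dispatch guard below are illustrative, not part of the
// existing suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_matches_scalar() {
        // Skip on machines without AVX2 rather than invoking UB.
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // Odd length exercises both the vector body and the scalar tail.
        let n = 1027;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let mut out = vec![0.0f32; n];
        // SAFETY: AVX2 was detected above and all slices have length `n`.
        unsafe { add(&a, &b, &mut out) };
        for i in 0..n {
            // vaddps and scalar f32 addition are both IEEE single-precision,
            // so the results are bit-identical.
            assert_eq!(out[i], a[i] + b[i]);
        }
    }
}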