pub mod erf;
pub mod kernels;

/// Abstraction over a SIMD instruction set operating on `f32` lanes.
///
/// Implemented by the per-arch backends below (AVX, NEON, simd128) so the
/// generic kernels in this module can be written once per operation.
#[allow(unused)]
trait Cpu<const ARR: usize> {
    /// One SIMD register worth of `f32` lanes.
    type Unit;
    /// A fixed-size collection of `Unit`s used as parallel accumulators.
    type Array;
    /// Scalar elements consumed per unrolled loop iteration; the kernels
    /// below assume this is a power of two (they mask with `k & !(STEP - 1)`).
    const STEP: usize;
    /// Elements per register.
    const EPR: usize;

    /// Number of accumulators in `Array` (presumably `ARR` — confirm in impls).
    fn n() -> usize;
    /// Returns a register with all lanes zeroed.
    unsafe fn zero() -> Self::Unit;
    /// Returns an `Array` with every accumulator zeroed.
    unsafe fn zero_array() -> Self::Array;
    /// Loads `EPR` consecutive `f32`s starting at `mem_addr`.
    unsafe fn load(mem_addr: *const f32) -> Self::Unit;
    /// Lane-wise addition.
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    /// Fused multiply-add; the kernels below use it as `a + b * c`-style
    /// accumulation (exact operand roles are backend-defined — confirm there).
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    /// Horizontally folds all accumulators in `x` into the `f32` at `y`.
    /// NOTE(review): the SIMD `vec_dot_f32` below passes `c` straight through
    /// and then adds a scalar tail, so implementations must *write* (not
    /// accumulate into) `*y` — confirm in each backend.
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    /// Broadcasts `v` to every lane.
    unsafe fn from_f32(v: f32) -> Self::Unit;
    /// Stores the lanes of `a` to `EPR` consecutive `f32`s at `mem_addr`.
    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}

/// SIMD abstraction for kernels whose memory operands are `f16` (half
/// precision); reduction produces an `f32` result.
#[allow(unused)]
trait CpuF16<const ARR: usize> {
    /// One SIMD register worth of lanes.
    type Unit;
    /// A fixed-size collection of `Unit`s used as parallel accumulators.
    type Array;
    /// Scalar elements consumed per unrolled loop iteration; assumed to be a
    /// power of two by the `k & !(STEP - 1)` masking in `vec_dot_f16`.
    const STEP: usize;
    /// Elements per register.
    const EPR: usize;

    /// Number of accumulators in `Array`.
    fn n() -> usize;
    /// Returns a register with all lanes zeroed.
    unsafe fn zero() -> Self::Unit;
    /// Returns an `Array` with every accumulator zeroed.
    unsafe fn zero_array() -> Self::Array;
    /// Loads `EPR` consecutive `f16`s starting at `mem_addr`.
    unsafe fn load(mem_addr: *const f16) -> Self::Unit;
    /// Lane-wise addition.
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    /// Fused multiply-add; used by `vec_dot_f16` as `a + b * c`-style
    /// accumulation (exact operand roles are backend-defined).
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    /// Horizontally folds all accumulators in `x` into the single `f32` at `y`.
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    /// Broadcasts `v` to every lane.
    unsafe fn from_f32(v: f32) -> Self::Unit;
    /// Stores the lanes of `a` as `EPR` consecutive `f16`s at `mem_addr`.
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}

/// SIMD abstraction for kernels whose memory operands are `bf16` (bfloat16);
/// reduction produces an `f32` result. Mirrors [`CpuF16`].
#[allow(unused)]
trait CpuBF16<const ARR: usize> {
    /// One SIMD register worth of lanes.
    type Unit;
    /// A fixed-size collection of `Unit`s used as parallel accumulators.
    type Array;
    /// Scalar elements consumed per unrolled loop iteration; assumed to be a
    /// power of two by the `k & !(STEP - 1)` masking in `vec_dot_bf16`.
    const STEP: usize;
    /// Elements per register.
    const EPR: usize;

    /// Number of accumulators in `Array`.
    fn n() -> usize;
    /// Returns a register with all lanes zeroed.
    unsafe fn zero() -> Self::Unit;
    /// Returns an `Array` with every accumulator zeroed.
    unsafe fn zero_array() -> Self::Array;
    /// Loads `EPR` consecutive `bf16`s starting at `mem_addr`.
    unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
    /// Lane-wise addition.
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    /// Fused multiply-add; used by `vec_dot_bf16` as `a + b * c`-style
    /// accumulation (exact operand roles are backend-defined).
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    /// Horizontally folds all accumulators in `x` into the single `f32` at `y`.
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    /// Broadcasts `v` to every lane.
    unsafe fn from_f32(v: f32) -> Self::Unit;
    /// Stores the lanes of `a` as `EPR` consecutive `bf16`s at `mem_addr`.
    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx2")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx2")]
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;

/// Dot product of two `f32` rows of length `k`; the result is written to `*c`.
///
/// SIMD fast path: processes `STEP`-sized chunks with `n()` parallel register
/// accumulators, then handles the remaining elements with scalar code.
///
/// # Safety
/// `a_row` and `b_row` must be valid for `k` reads; `c` must be valid for
/// reads and writes of one `f32`.
#[cfg(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // Largest multiple of STEP (a power of two) not exceeding k.
    let nv = k & !(CurrentCpu::STEP - 1);

    let mut acc = CurrentCpu::zero_array();
    let mut va = CurrentCpu::zero_array();
    let mut vb = CurrentCpu::zero_array();

    let mut i = 0;
    while i < nv {
        for r in 0..CurrentCpu::n() {
            let off = i + r * CurrentCpu::EPR;
            va[r] = CurrentCpu::load(a_row.add(off));
            vb[r] = CurrentCpu::load(b_row.add(off));
            acc[r] = CurrentCpu::vec_fma(acc[r], va[r], vb[r]);
        }
        i += CurrentCpu::STEP;
    }

    // Fold the register accumulators into *c, then add the scalar tail.
    CurrentCpu::vec_reduce(acc, c);
    for j in nv..k {
        *c += *a_row.add(j) * *b_row.add(j);
    }
}

/// Dot product of two `f32` rows of length `k`; the result is written to `*c`.
///
/// Scalar fallback used when no supported SIMD feature is enabled.
///
/// # Safety
/// `a_row` and `b_row` must be valid for `k` reads; `c` must be valid for a
/// write of one `f32`.
#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // Accumulate locally and overwrite *c. The previous `*c += ...` form
    // silently depended on the caller pre-zeroing *c, which diverged from the
    // SIMD path above (whose vec_reduce writes the result) and from the scalar
    // vec_dot_f16/vec_dot_bf16 fallbacks below (`*c = sum`); results were
    // platform-dependent for a nonzero initial *c.
    let mut sum = 0f32;
    for i in 0..k {
        sum += *a_row.add(i) * (*b_row.add(i));
    }
    *c = sum;
}

/// Sums the `k` `f32` values starting at `row` and writes the total to `*b`.
///
/// SIMD fast path: accumulates `STEP`-sized chunks in registers, then adds
/// the remaining elements with scalar code.
///
/// # Safety
/// `row` must be valid for `k` reads; `b` must be valid for reads and writes
/// of one `f32`.
#[cfg(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    // Largest multiple of STEP (a power of two) not exceeding k.
    let nv = k & !(CurrentCpu::STEP - 1);

    let mut acc = CurrentCpu::zero_array();
    let mut v = CurrentCpu::zero_array();

    let mut i = 0;
    while i < nv {
        for r in 0..CurrentCpu::n() {
            v[r] = CurrentCpu::load(row.add(i + r * CurrentCpu::EPR));
            acc[r] = CurrentCpu::vec_add(acc[r], v[r]);
        }
        i += CurrentCpu::STEP;
    }

    // Fold the register accumulators into *b, then add the scalar tail.
    CurrentCpu::vec_reduce(acc, b);
    for j in nv..k {
        *b += *row.add(j)
    }
}

/// Sums the `k` `f32` values starting at `row` and writes the total to `*b`.
///
/// Scalar fallback used when no supported SIMD feature is enabled.
///
/// # Safety
/// `row` must be valid for `k` reads; `b` must be valid for a write of one
/// `f32`.
#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    let mut total = 0f32;
    let mut i = 0;
    while i < k {
        total += *row.add(i);
        i += 1;
    }
    *b = total;
}

/// Dot product of two `f16` rows of length `k`; the `f32` result is written
/// to `*c`.
///
/// AVX2 fast path: processes `STEP`-sized chunks through the
/// `CurrentCpuF16` backend, then handles the tail with scalar `f32`
/// conversions.
///
/// # Safety
/// `a_row` and `b_row` must be valid for `k` reads; `c` must be valid for a
/// write of one `f32`.
#[cfg(target_feature = "avx2")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    // Largest multiple of STEP (a power of two) not exceeding k.
    let nv = k & !(CurrentCpuF16::STEP - 1);

    let mut acc = CurrentCpuF16::zero_array();
    let mut va = CurrentCpuF16::zero_array();
    let mut vb = CurrentCpuF16::zero_array();

    let mut i = 0;
    while i < nv {
        for r in 0..CurrentCpuF16::n() {
            let off = i + r * CurrentCpuF16::EPR;
            va[r] = CurrentCpuF16::load(a_row.add(off));
            vb[r] = CurrentCpuF16::load(b_row.add(off));
            acc[r] = CurrentCpuF16::vec_fma(acc[r], va[r], vb[r]);
        }
        i += CurrentCpuF16::STEP;
    }

    let mut total = 0.0f32;
    CurrentCpuF16::vec_reduce(acc, &mut total);

    // Scalar tail for the remaining k - nv elements.
    for j in nv..k {
        total += (*a_row.add(j)).to_f32() * (*b_row.add(j)).to_f32();
    }
    *c = total;
}

/// Dot product of two `bf16` rows of length `k`; the `f32` result is written
/// to `*c`.
///
/// AVX2 fast path: processes `STEP`-sized chunks through the
/// `CurrentCpuBF16` backend, then handles the tail with scalar `f32`
/// conversions.
///
/// # Safety
/// `a_row` and `b_row` must be valid for `k` reads; `c` must be valid for a
/// write of one `f32`.
#[cfg(target_feature = "avx2")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
    // Largest multiple of STEP (a power of two) not exceeding k.
    let nv = k & !(CurrentCpuBF16::STEP - 1);

    let mut acc = CurrentCpuBF16::zero_array();
    let mut va = CurrentCpuBF16::zero_array();
    let mut vb = CurrentCpuBF16::zero_array();

    let mut i = 0;
    while i < nv {
        for r in 0..CurrentCpuBF16::n() {
            let off = i + r * CurrentCpuBF16::EPR;
            va[r] = CurrentCpuBF16::load(a_row.add(off));
            vb[r] = CurrentCpuBF16::load(b_row.add(off));
            acc[r] = CurrentCpuBF16::vec_fma(acc[r], va[r], vb[r]);
        }
        i += CurrentCpuBF16::STEP;
    }

    let mut total = 0.0f32;
    CurrentCpuBF16::vec_reduce(acc, &mut total);

    // Scalar tail for the remaining k - nv elements.
    for j in nv..k {
        total += (*a_row.add(j)).to_f32() * (*b_row.add(j)).to_f32();
    }
    *c = total;
}

222#[cfg(not(target_feature = "avx2"))]
223#[inline(always)]
224pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
225 let mut sum = 0.0;
227 for i in 0..k {
228 sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
229 }
230 *c = sum;
231}
232
233#[cfg(not(target_feature = "avx2"))]
234#[inline(always)]
235pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
236 let mut sum = 0.0;
238 for i in 0..k {
239 sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
240 }
241 *c = sum;
242}