
candle_core/cpu/mod.rs

//! Traits and methods for CPU-backed Tensors

pub mod erf;
pub mod kernels;

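// The traits below describe the SIMD operations each backend has to provide,
// with one variant per element type. `Unit` is a single SIMD register,
// `Array` a group of `ARR` registers, `EPR` the number of elements per
// register, and `STEP` the number of elements consumed per outer-loop
// iteration (`n()` registers at a time). The per-architecture modules further
// down implement these traits and are re-exported as `CurrentCpu*`.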
#[allow(unused)]
trait Cpu<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f32) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}

#[allow(unused)]
trait CpuF16<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f16) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}

#[allow(unused)]
trait CpuBF16<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx2")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx2")]
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;

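// `CurrentCpu` (plus `CurrentCpuF16`/`CurrentCpuBF16` on AVX2) is whichever
// backend was selected above for the compilation target. Each helper below
// comes in two flavours: a vectorized version compiled when one of the
// supported SIMD target features (neon, avx2, simd128) is enabled, and a
// plain scalar fallback otherwise.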
#[cfg(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // Round `k` down to a multiple of `STEP` (this relies on `STEP` being a
    // power of two); the tail is handled scalarly below.
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut ax = CurrentCpu::zero_array();
    let mut ay = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
            ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));

            sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    // Reduce the per-register partial sums into `*c`.
    CurrentCpu::vec_reduce(sum, c);

    // leftovers
    for i in np..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // Scalar fallback: accumulate the full dot product into `*c`. Note that
    // `*c` is never reset here, so callers presumably pass in a
    // zero-initialized accumulator.
    for i in 0..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}

#[cfg(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut x = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
            sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
        }
    }

    CurrentCpu::vec_reduce(sum, b);

    // leftovers
    for i in np..k {
        *b += *row.add(i)
    }
}

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx2",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    *b = 0f32;
    for i in 0..k {
        *b += *row.add(i)
    }
}

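// The half-precision dot products below accumulate in f32 (`sumf`) rather
// than in the narrower f16/bf16 formats, and write the result back to `*c`
// once at the end.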
#[cfg(target_feature = "avx2")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    let mut sumf = 0.0f32;
    let np = k & !(CurrentCpuF16::STEP - 1);

    let mut sum = CurrentCpuF16::zero_array();
    let mut ax = CurrentCpuF16::zero_array();
    let mut ay = CurrentCpuF16::zero_array();

    for i in (0..np).step_by(CurrentCpuF16::STEP) {
        for j in 0..CurrentCpuF16::n() {
            ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
            ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));

            sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpuF16::vec_reduce(sum, &mut sumf);

    // leftovers
    for i in np..k {
        sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sumf;
}

#[cfg(target_feature = "avx2")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
    let mut sumf = 0.0f32;
    let np = k & !(CurrentCpuBF16::STEP - 1);

    let mut sum = CurrentCpuBF16::zero_array();
    let mut ax = CurrentCpuBF16::zero_array();
    let mut ay = CurrentCpuBF16::zero_array();

    for i in (0..np).step_by(CurrentCpuBF16::STEP) {
        for j in 0..CurrentCpuBF16::n() {
            ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
            ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));

            sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpuBF16::vec_reduce(sum, &mut sumf);

    // leftovers
    for i in np..k {
        sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sumf;
}

#[cfg(not(target_feature = "avx2"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    // Scalar fallback over the whole row.
    let mut sum = 0.0;
    for i in 0..k {
        sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sum;
}

#[cfg(not(target_feature = "avx2"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
    // Scalar fallback over the whole row.
    let mut sum = 0.0;
    for i in 0..k {
        sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sum;
}
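
// A minimal test sketch (illustrative, not from the original source) showing
// how the crate-private helpers above could be checked against naive loops.
// It assumes the usage contract visible in the code: pointers to rows of
// length `k` and an accumulator that starts at zero. Test names, input sizes
// and the tolerance are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vec_dot_f32_matches_naive() {
        let a: Vec<f32> = (0..37).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = (0..37).map(|i| 1.0 - i as f32 * 0.25).collect();
        let naive: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        let mut got = 0.0f32;
        unsafe { vec_dot_f32(a.as_ptr(), b.as_ptr(), &mut got, a.len()) };
        assert!((got - naive).abs() < 1e-3);
    }

    #[test]
    fn vec_sum_matches_naive() {
        let row: Vec<f32> = (0..53).map(|i| (i % 7) as f32).collect();
        let naive: f32 = row.iter().sum();

        let mut got = 0.0f32;
        unsafe { vec_sum(row.as_ptr(), &mut got, row.len()) };
        assert!((got - naive).abs() < 1e-3);
    }
}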