1use rayon::{ThreadPool, prelude::*};
2use yscv_tensor::{AlignedVec, Tensor};
3
4use super::super::error::KernelError;
5use super::config::{
6 BinaryKind, PARALLEL_SLICE_CHUNK_ELEMENTS, ParallelElementwiseConfig, should_parallelize_len,
7};
8use super::simd::{
9 binary_same_shape_dispatch, exp_slice_dispatch, relu_slice_dispatch, relu_to_slice_dispatch,
10 sigmoid_slice_dispatch, silu_slice_dispatch, tanh_slice_dispatch,
11};
12
#[allow(unsafe_code)]
mod par {
    //! Minimal parallel-for primitive.
    //!
    //! On macOS the iterations are dispatched through libdispatch (GCD) via
    //! `dispatch_apply_f`, which blocks the caller until all iterations have
    //! completed. On other platforms rayon's parallel iterator is used.

    #[cfg(target_os = "macos")]
    use std::ffi::c_void;

    // Raw FFI bindings to libdispatch (provided by libSystem on macOS).
    #[cfg(target_os = "macos")]
    #[allow(unsafe_code)]
    unsafe extern "C" {
        // Returns the global concurrent queue; identifier 0 selects default QoS.
        fn dispatch_get_global_queue(identifier: isize, flags: usize) -> *const c_void;
        // Invokes `work(context, i)` for every i in 0..iterations, possibly in
        // parallel, and returns only after all invocations have finished.
        fn dispatch_apply_f(
            iterations: usize,
            queue: *const c_void,
            context: *mut c_void,
            work: unsafe extern "C" fn(*mut c_void, usize),
        );
    }

    /// Runs `f(i)` for each `i` in `0..n`, possibly in parallel, returning
    /// only after every invocation has completed (macOS / GCD backend).
    #[cfg(target_os = "macos")]
    #[inline]
    #[allow(unsafe_code)]
    pub fn parallel_for<F: Fn(usize) + Sync>(n: usize, f: F) {
        // Trampoline: recovers the closure from the opaque context pointer.
        #[allow(unsafe_code)]
        unsafe extern "C" fn call<F: Fn(usize) + Sync>(ctx: *mut c_void, i: usize) {
            // SAFETY: `ctx` was derived from `&f` below; `f` is alive for the
            // whole `dispatch_apply_f` call because that call is synchronous.
            unsafe {
                (*(ctx as *const F))(i);
            }
        }
        let queue = unsafe { dispatch_get_global_queue(0, 0) };
        // SAFETY: `dispatch_apply_f` blocks until all work items return, so
        // borrowing `f` by raw pointer for its duration is sound; `F: Sync`
        // makes the concurrent shared access from worker threads valid.
        unsafe {
            dispatch_apply_f(n, queue, &f as *const F as *mut c_void, call::<F>);
        }
    }

    /// Runs `f(i)` for each `i` in `0..n` via rayon (non-macOS fallback).
    #[cfg(not(target_os = "macos"))]
    #[inline]
    pub fn parallel_for<F: Fn(usize) + Sync + Send>(n: usize, f: F) {
        // Skip rayon's scheduling overhead when there is at most one item.
        if n <= 1 {
            for i in 0..n {
                f(i);
            }
            return;
        }
        use rayon::prelude::*;
        (0..n).into_par_iter().for_each(f);
    }
}
61
62#[inline]
64#[allow(unsafe_code)]
65pub fn relu(input: &Tensor) -> Tensor {
66 let input_data = input.data();
67 let len = input_data.len();
68 let mut output = AlignedVec::<f32>::uninitialized(len);
69
70 const PAR_THRESH: usize = 100_000;
71 if len >= PAR_THRESH {
72 let n_chunks = std::thread::available_parallelism()
73 .map(|p| p.get())
74 .unwrap_or(4);
75 let chunk = len.div_ceil(n_chunks);
76 let in_ptr = input_data.as_ptr() as usize;
77 let out_ptr = output.as_mut_ptr() as usize;
78 par::parallel_for(n_chunks, |t| {
79 let start = t * chunk;
80 let end = (start + chunk).min(len);
81 let inp = unsafe {
82 std::slice::from_raw_parts((in_ptr as *const f32).add(start), end - start)
83 };
84 let out = unsafe {
85 std::slice::from_raw_parts_mut((out_ptr as *mut f32).add(start), end - start)
86 };
87 relu_to_slice_dispatch(inp, out);
88 });
89 } else {
90 relu_to_slice_dispatch(input_data, &mut output);
91 }
92
93 Tensor::from_raw_parts(input.shape(), input.strides(), output)
94}
95
96#[inline]
98pub fn relu_inplace(tensor: &mut Tensor) {
99 relu_slice_dispatch(tensor.data_mut());
100}
101
102#[inline]
104pub fn relu_out(input: &Tensor, output: &mut Tensor) {
105 debug_assert_eq!(input.shape(), output.shape());
106 relu_to_slice_dispatch(input.data(), output.data_mut());
107}
108
109pub fn sigmoid(input: &Tensor) -> Tensor {
111 sigmoid_with_config(input, ParallelElementwiseConfig::disabled())
112}
113
114pub fn relu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
116 relu_with_config_and_pool(input, config, None)
117}
118
119pub fn sigmoid_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
121 sigmoid_with_config_and_pool(input, config, None)
122}
123
124#[allow(unsafe_code)]
128pub fn relu_with_config_and_pool(
129 input: &Tensor,
130 config: ParallelElementwiseConfig,
131 thread_pool: Option<&ThreadPool>,
132) -> Tensor {
133 let input_data = input.data();
134 let len = input_data.len();
135 let mut output = AlignedVec::<f32>::uninitialized(len);
136 if should_parallelize_len(len, config.min_parallel_elements, thread_pool) {
137 let mut work = || {
138 output
139 .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
140 .enumerate()
141 .for_each(|(chunk_idx, out_chunk)| {
142 let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
143 let end = start + out_chunk.len();
144 relu_to_slice_dispatch(&input_data[start..end], out_chunk);
145 });
146 };
147 if let Some(pool) = thread_pool {
148 pool.install(work);
149 } else {
150 work();
151 }
152 } else {
153 relu_to_slice_dispatch(input_data, &mut output);
154 }
155 Tensor::from_raw_parts(input.shape(), input.strides(), output)
156}
157
158#[allow(unsafe_code)]
162pub fn sigmoid_with_config_and_pool(
163 input: &Tensor,
164 _config: ParallelElementwiseConfig,
165 _thread_pool: Option<&ThreadPool>,
166) -> Tensor {
167 let input_data = input.data();
168 let len = input_data.len();
169 let mut output = AlignedVec::<f32>::uninitialized(len);
173 sigmoid_slice_dispatch(input_data, &mut output);
174 Tensor::from_raw_parts(input.shape(), input.strides(), output)
175}
176
177pub fn exp(input: &Tensor) -> Tensor {
179 exp_with_config(input, ParallelElementwiseConfig::disabled())
180}
181
182pub fn exp_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
184 exp_with_config_and_pool(input, config, None)
185}
186
187#[allow(unsafe_code)]
191pub fn exp_with_config_and_pool(
192 input: &Tensor,
193 config: ParallelElementwiseConfig,
194 thread_pool: Option<&ThreadPool>,
195) -> Tensor {
196 let input_data = input.data();
197 let len = input_data.len();
198 let mut output = AlignedVec::<f32>::uninitialized(len);
199 if should_parallelize_len(len, config.min_parallel_elements, thread_pool) {
200 let mut work = || {
201 output
202 .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
203 .enumerate()
204 .for_each(|(chunk_idx, out_chunk)| {
205 let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
206 let end = start + out_chunk.len();
207 exp_slice_dispatch(&input_data[start..end], out_chunk);
208 });
209 };
210 if let Some(pool) = thread_pool {
211 pool.install(work);
212 } else {
213 work();
214 }
215 } else {
216 exp_slice_dispatch(input_data, &mut output);
217 }
218 Tensor::from_raw_parts(input.shape(), input.strides(), output)
219}
220
221pub fn tanh_act(input: &Tensor) -> Tensor {
223 tanh_act_with_config(input, ParallelElementwiseConfig::disabled())
224}
225
226pub fn tanh_act_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
228 tanh_act_with_config_and_pool(input, config, None)
229}
230
231#[allow(unsafe_code)]
235pub fn tanh_act_with_config_and_pool(
236 input: &Tensor,
237 _config: ParallelElementwiseConfig,
238 _thread_pool: Option<&ThreadPool>,
239) -> Tensor {
240 let input_data = input.data();
241 let len = input_data.len();
242 let mut output = AlignedVec::<f32>::uninitialized(len);
245 tanh_slice_dispatch(input_data, &mut output);
246 Tensor::from_raw_parts(input.shape(), input.strides(), output)
247}
248
// Activations below parallelize only past this element count; the chunk size
// balances rayon scheduling overhead against cache-friendly work units.
const ACTIVATION_PARALLEL_THRESHOLD: usize = 65536;
const ACTIVATION_CHUNK_SIZE: usize = 8192;
251
252#[allow(unsafe_code)]
258pub fn gelu(input: &Tensor) -> Tensor {
259 let src = input.data();
260 let len = src.len();
261 let mut output = AlignedVec::<f32>::uninitialized(len);
262 if len >= ACTIVATION_PARALLEL_THRESHOLD {
263 output
264 .par_chunks_mut(ACTIVATION_CHUNK_SIZE)
265 .enumerate()
266 .for_each(|(ci, out_chunk)| {
267 let start = ci * ACTIVATION_CHUNK_SIZE;
268 gelu_slice_out(&src[start..start + out_chunk.len()], out_chunk);
269 });
270 } else {
271 gelu_slice_out(src, &mut output);
272 }
273 Tensor::from_raw_parts(input.shape(), input.strides(), output)
274}
275
276pub fn silu(input: &Tensor) -> Tensor {
280 silu_with_config(input, ParallelElementwiseConfig::disabled())
281}
282
283pub fn silu_with_config(input: &Tensor, config: ParallelElementwiseConfig) -> Tensor {
285 silu_with_config_and_pool(input, config, None)
286}
287
288#[allow(unsafe_code)]
292pub fn silu_with_config_and_pool(
293 input: &Tensor,
294 _config: ParallelElementwiseConfig,
295 _thread_pool: Option<&ThreadPool>,
296) -> Tensor {
297 let input_data = input.data();
298 let len = input_data.len();
299 let mut output = AlignedVec::<f32>::uninitialized(len);
302 silu_slice_dispatch(input_data, &mut output);
303 Tensor::from_raw_parts(input.shape(), input.strides(), output)
304}
305
306pub fn mish(input: &Tensor) -> Tensor {
308 let mut output = input.clone();
309 let data = output.data_mut();
310 if data.len() >= ACTIVATION_PARALLEL_THRESHOLD {
311 data.par_chunks_mut(ACTIVATION_CHUNK_SIZE)
312 .for_each(mish_slice);
313 } else {
314 mish_slice(data);
315 }
316 output
317}
318
/// Sigmoid-approximation GELU: `dst[i] = src[i] * sigmoid(1.702 * src[i])`.
///
/// `src` and `dst` are expected to have equal lengths (checked in debug
/// builds).
fn gelu_slice_out(src: &[f32], dst: &mut [f32]) {
    debug_assert_eq!(src.len(), dst.len());
    // Zip instead of indexing so the compiler can elide per-element bounds
    // checks and vectorize the loop.
    for (d, &x) in dst.iter_mut().zip(src) {
        let a = 1.702 * x;
        let s = 1.0 / (1.0 + (-a).exp());
        *d = x * s;
    }
}
328
/// In-place Mish activation: `x <- x * tanh(ln(1 + e^x))`.
///
/// For large `x`, `(1.0 + x.exp())` saturates to f32 infinity; `ln` and
/// `tanh` then yield 1.0, so the result degrades gracefully to `x`.
fn mish_slice(data: &mut [f32]) {
    // iter_mut avoids indexed access and its per-element bounds checks.
    for x in data.iter_mut() {
        let softplus = (1.0 + x.exp()).ln();
        *x *= softplus.tanh();
    }
}
336
337pub fn add_with_config(
339 lhs: &Tensor,
340 rhs: &Tensor,
341 config: ParallelElementwiseConfig,
342) -> Result<Tensor, KernelError> {
343 add_with_config_and_pool(lhs, rhs, config, None)
344}
345
346pub fn add_with_config_and_pool(
347 lhs: &Tensor,
348 rhs: &Tensor,
349 config: ParallelElementwiseConfig,
350 thread_pool: Option<&ThreadPool>,
351) -> Result<Tensor, KernelError> {
352 binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Add)
353}
354
355pub fn sub_with_config(
357 lhs: &Tensor,
358 rhs: &Tensor,
359 config: ParallelElementwiseConfig,
360) -> Result<Tensor, KernelError> {
361 sub_with_config_and_pool(lhs, rhs, config, None)
362}
363
364pub fn sub_with_config_and_pool(
365 lhs: &Tensor,
366 rhs: &Tensor,
367 config: ParallelElementwiseConfig,
368 thread_pool: Option<&ThreadPool>,
369) -> Result<Tensor, KernelError> {
370 binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Sub)
371}
372
373pub fn mul_with_config(
375 lhs: &Tensor,
376 rhs: &Tensor,
377 config: ParallelElementwiseConfig,
378) -> Result<Tensor, KernelError> {
379 mul_with_config_and_pool(lhs, rhs, config, None)
380}
381
382pub fn mul_with_config_and_pool(
383 lhs: &Tensor,
384 rhs: &Tensor,
385 config: ParallelElementwiseConfig,
386 thread_pool: Option<&ThreadPool>,
387) -> Result<Tensor, KernelError> {
388 binary_with_config_and_pool(lhs, rhs, config, thread_pool, BinaryKind::Mul)
389}
390
391fn binary_with_config_and_pool(
392 lhs: &Tensor,
393 rhs: &Tensor,
394 config: ParallelElementwiseConfig,
395 thread_pool: Option<&ThreadPool>,
396 kind: BinaryKind,
397) -> Result<Tensor, KernelError> {
398 if lhs.shape() != rhs.shape() {
399 return binary_fallback(lhs, rhs, kind);
400 }
401
402 let left = lhs.data();
403 let right = rhs.data();
404 let shape = lhs.shape().to_vec();
405 let mut output = AlignedVec::<f32>::uninitialized(left.len());
406
407 if should_parallelize_len(left.len(), config.min_parallel_elements, thread_pool) {
408 let mut work = || {
409 output
410 .par_chunks_mut(PARALLEL_SLICE_CHUNK_ELEMENTS)
411 .enumerate()
412 .for_each(|(chunk_idx, out_chunk)| {
413 let start = chunk_idx * PARALLEL_SLICE_CHUNK_ELEMENTS;
414 let end = start + out_chunk.len();
415 binary_same_shape_dispatch(
416 &left[start..end],
417 &right[start..end],
418 out_chunk,
419 kind,
420 );
421 });
422 };
423
424 if let Some(pool) = thread_pool {
425 pool.install(work);
426 } else {
427 work();
428 }
429 } else {
430 binary_same_shape_dispatch(left, right, &mut output, kind);
431 }
432
433 Tensor::from_aligned(shape, output).map_err(Into::into)
434}
435
436fn binary_fallback(lhs: &Tensor, rhs: &Tensor, kind: BinaryKind) -> Result<Tensor, KernelError> {
437 match kind {
438 BinaryKind::Add => lhs.add(rhs),
439 BinaryKind::Sub => lhs.sub(rhs),
440 BinaryKind::Mul => lhs.mul(rhs),
441 }
442 .map_err(Into::into)
443}