use super::Backend;
use crate::device::DeviceCapabilities;
use crate::dtype::{Float, Numeric, Scalar};
use rayon::prelude::*;
use sysinfo::System;

/// Element count above which element-wise kernels switch to rayon parallel iterators.
const PARALLEL_THRESHOLD: usize = 4096;

/// CPU compute backend.
///
/// Element-wise kernels use rayon data parallelism above `PARALLEL_THRESHOLD`,
/// and f32/f64 matrix multiplication is delegated to the `matrixmultiply` crate.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

impl CpuBackend {
    /// Creates a new CPU backend.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Backend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn is_available(&self) -> bool {
        true
    }

    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: "CPU".to_string(),
            total_memory: get_system_memory(),
            available_memory: get_available_memory(),
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: num_cpus(),
            compute_capability: None,
        }
    }

    fn allocate(&self, size: usize) -> *mut u8 {
        if size == 0 {
            return std::ptr::null_mut();
        }
        unsafe {
            // 64-byte alignment keeps buffers cache-line and SIMD friendly.
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::alloc(layout)
        }
    }

    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if ptr.is_null() || size == 0 {
            return;
        }
        unsafe {
            // Must match the layout used in `allocate`.
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::dealloc(ptr, layout);
        }
    }

    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn synchronize(&self) {
        // No-op: CPU operations complete synchronously.
    }
}

fn get_system_memory() -> usize {
    let sys = System::new_all();
    sys.total_memory() as usize
}

fn get_available_memory() -> usize {
    let sys = System::new_all();
    sys.available_memory() as usize
}

fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(std::num::NonZeroUsize::get)
        .unwrap_or(1)
}

impl CpuBackend {
    /// Element-wise addition: `dst[i] = a[i] + b[i]`.
    pub fn add<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val + *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + b[i];
            }
        }
    }

    /// Element-wise subtraction: `dst[i] = a[i] - b[i]`.
    pub fn sub<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val - *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] - b[i];
            }
        }
    }

    /// Element-wise multiplication: `dst[i] = a[i] * b[i]`.
    pub fn mul<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val * *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * b[i];
            }
        }
    }

    /// Element-wise division: `dst[i] = a[i] / b[i]`.
    pub fn div<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val / *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] / b[i];
            }
        }
    }

    /// Adds `scalar` to every element: `dst[i] = a[i] + scalar`.
    pub fn add_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val + scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + scalar;
            }
        }
    }

    /// Multiplies every element by `scalar`: `dst[i] = a[i] * scalar`.
    pub fn mul_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * scalar;
            }
        }
    }

    /// Element-wise negation, written as `0 - a[i]` so it only requires `Numeric`.
    pub fn neg<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::zero() - *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::zero() - a[i];
            }
        }
    }

    /// Element-wise absolute value: `dst[i] = |a[i]|`.
    pub fn abs<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val < T::zero() {
                    T::zero() - *a_val
                } else {
                    *a_val
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] < T::zero() {
                    T::zero() - a[i]
                } else {
                    a[i]
                };
            }
        }
    }
}

impl CpuBackend {
    /// Rectified linear unit: `dst[i] = max(a[i], 0)`.
    pub fn relu<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val > T::zero() {
                    *a_val
                } else {
                    T::zero()
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
            }
        }
    }

    /// Logistic sigmoid: `dst[i] = 1 / (1 + exp(-a[i]))`.
    pub fn sigmoid<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::one() / (T::one() + (-*a_val).exp_value());
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
            }
        }
    }

    /// Hyperbolic tangent: `dst[i] = tanh(a[i])`.
    pub fn tanh<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.tanh_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].tanh_value();
            }
        }
    }

    /// Element-wise exponential: `dst[i] = exp(a[i])`.
    pub fn exp<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.exp_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].exp_value();
            }
        }
    }

    /// Element-wise natural logarithm: `dst[i] = ln(a[i])`.
    pub fn ln<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.ln_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].ln_value();
            }
        }
    }

    /// Element-wise square root: `dst[i] = sqrt(a[i])`.
    pub fn sqrt<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.sqrt_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].sqrt_value();
            }
        }
    }

    /// Element-wise square: `dst[i] = a[i] * a[i]`.
    pub fn square<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * a[i];
            }
        }
    }
}

impl CpuBackend {
    /// Sum of all elements, starting from `T::zero()`.
    pub fn sum<T: Numeric>(a: &[T]) -> T {
        let mut result = T::zero();
        for &val in a {
            result = result + val;
        }
        result
    }

    /// Product of all elements, starting from `T::one()`.
    pub fn prod<T: Numeric>(a: &[T]) -> T {
        let mut result = T::one();
        for &val in a {
            result = result * val;
        }
        result
    }

    /// Maximum element, or `None` if the slice is empty.
    pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val > result {
                result = val;
            }
        }
        Some(result)
    }

    /// Minimum element, or `None` if the slice is empty.
    pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val < result {
                result = val;
            }
        }
        Some(result)
    }

    /// Arithmetic mean, or `None` if the slice is empty.
    pub fn mean<T: Float>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let sum = Self::sum(a);
        let len = T::from(a.len()).unwrap_or(T::one());
        Some(sum / len)
    }

    /// Index of the first maximum element, or `None` if the slice is empty.
    pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut max_idx = 0;
        let mut max_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }
        Some(max_idx)
    }

    /// Index of the first minimum element, or `None` if the slice is empty.
    pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut min_idx = 0;
        let mut min_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val < min_val {
                min_val = val;
                min_idx = i;
            }
        }
        Some(min_idx)
    }
}

impl CpuBackend {
    /// Matrix multiplication `C = A * B` for row-major matrices,
    /// where `A` is `m x k`, `B` is `k x n`, and `C` is `m x n`.
    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        use std::any::TypeId;

        // Fast path: dispatch f32 and f64 to the optimized GEMM routines.
        if TypeId::of::<T>() == TypeId::of::<f32>() {
            unsafe {
                let a_f32: &[f32] = &*(a as *const [T] as *const [f32]);
                let b_f32: &[f32] = &*(b as *const [T] as *const [f32]);
                let c_f32: &mut [f32] = &mut *(c as *mut [T] as *mut [f32]);
                Self::matmul_f32(c_f32, a_f32, b_f32, m, n, k);
            }
            return;
        }

        if TypeId::of::<T>() == TypeId::of::<f64>() {
            unsafe {
                let a_f64: &[f64] = &*(a as *const [T] as *const [f64]);
                let b_f64: &[f64] = &*(b as *const [T] as *const [f64]);
                let c_f64: &mut [f64] = &mut *(c as *mut [T] as *mut [f64]);
                Self::matmul_f64(c_f64, a_f64, b_f64, m, n, k);
            }
            return;
        }

        // Generic fallback: cache-blocked triple loop.
        const BLOCK_SIZE: usize = 64;

        for val in c.iter_mut() {
            *val = T::zero();
        }

        for i0 in (0..m).step_by(BLOCK_SIZE) {
            let i_end = (i0 + BLOCK_SIZE).min(m);
            for p0 in (0..k).step_by(BLOCK_SIZE) {
                let p_end = (p0 + BLOCK_SIZE).min(k);
                for j0 in (0..n).step_by(BLOCK_SIZE) {
                    let j_end = (j0 + BLOCK_SIZE).min(n);

                    for i in i0..i_end {
                        for p in p0..p_end {
                            let a_val = a[i * k + p];
                            for j in j0..j_end {
                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
                            }
                        }
                    }
                }
            }
        }
    }

    /// Single-precision GEMM: `C = alpha * A * B + beta * C` for row-major matrices.
    pub fn sgemm(
        c: &mut [f32],
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        beta: f32,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize, // row stride of A
                1,          // column stride of A
                b.as_ptr(),
                n as isize, // row stride of B
                1,          // column stride of B
                beta,
                c.as_mut_ptr(),
                n as isize, // row stride of C
                1,          // column stride of C
            );
        }
    }

    /// Double-precision GEMM: `C = alpha * A * B + beta * C` for row-major matrices.
    pub fn dgemm(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        beta: f64,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize, // row stride of A
                1,          // column stride of A
                b.as_ptr(),
                n as isize, // row stride of B
                1,          // column stride of B
                beta,
                c.as_mut_ptr(),
                n as isize, // row stride of C
                1,          // column stride of C
            );
        }
    }

    /// `C = A * B` for f32, via `sgemm` with `alpha = 1.0` and `beta = 0.0`.
    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// `C = A * B` for f64, via `dgemm` with `alpha = 1.0` and `beta = 0.0`.
    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Transposes a row-major `rows x cols` matrix into a row-major `cols x rows` matrix.
    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
        debug_assert_eq!(src.len(), rows * cols);
        debug_assert_eq!(dst.len(), rows * cols);

        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
    }

    /// Dot product of two equal-length slices.
    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
        debug_assert_eq!(a.len(), b.len());

        let mut sum = T::zero();
        for i in 0..a.len() {
            sum = sum + a[i] * b[i];
        }
        sum
    }
}

impl CpuBackend {
    /// Element-wise equality: `dst[i] = (a[i] == b[i])`.
    pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] == b[i];
        }
    }

    /// Element-wise less-than: `dst[i] = (a[i] < b[i])`.
    pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] < b[i];
        }
    }

    /// Element-wise greater-than: `dst[i] = (a[i] > b[i])`.
    pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] > b[i];
        }
    }
}

impl CpuBackend {
    /// Fills `dst` with `value`.
    pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
        for elem in dst.iter_mut() {
            *elem = value;
        }
    }

    /// Fills `dst` with the zero value of `T`.
    pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
        Self::fill(dst, T::zeroed());
    }

    /// Copies `src` into `dst`; the slices must have equal length.
    pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
        debug_assert_eq!(dst.len(), src.len());
        dst.copy_from_slice(src);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add(&mut c, &a, &b);
        assert_eq!(c, [5.0, 7.0, 9.0]);
    }

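    // Usage sketch: `sub` and `div` follow the same element-wise pattern as `add`,
    // so the expected values fall out of plain per-element arithmetic.
    #[test]
    fn test_sub_div() {
        let a = [4.0_f32, 9.0, 16.0];
        let b = [2.0_f32, 3.0, 4.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::sub(&mut c, &a, &b);
        assert_eq!(c, [2.0, 6.0, 12.0]);

        CpuBackend::div(&mut c, &a, &b);
        assert_eq!(c, [2.0, 3.0, 4.0]);
    }
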
    #[test]
    fn test_mul() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [2.0_f32, 2.0, 2.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::mul(&mut c, &a, &b);
        assert_eq!(c, [4.0, 6.0, 8.0]);
    }

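    // Usage sketch: the scalar variants broadcast a single value across the input slice.
    #[test]
    fn test_scalar_ops() {
        let a = [1.0_f32, 2.0, 3.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add_scalar(&mut c, &a, 10.0);
        assert_eq!(c, [11.0, 12.0, 13.0]);

        CpuBackend::mul_scalar(&mut c, &a, 2.0);
        assert_eq!(c, [2.0, 4.0, 6.0]);
    }
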
    #[test]
    fn test_relu() {
        let a = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut b = [0.0_f32; 4];

        CpuBackend::relu(&mut b, &a);
        assert_eq!(b, [0.0, 0.0, 1.0, 2.0]);
    }

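    // Usage sketch: sigmoid(0) = 1 / (1 + e^0) = 0.5 and tanh(0) = 0, both exact in f32.
    #[test]
    fn test_sigmoid_tanh() {
        let a = [0.0_f32];
        let mut b = [0.0_f32; 1];

        CpuBackend::sigmoid(&mut b, &a);
        assert_eq!(b, [0.5]);

        CpuBackend::tanh(&mut b, &a);
        assert_eq!(b, [0.0]);
    }
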
    #[test]
    fn test_sum() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::sum(&a), 10.0);
    }

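    // Usage sketch: `prod` folds with `*` starting from one, and `mean` is sum / len
    // (returning `None` on an empty slice).
    #[test]
    fn test_prod_mean() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::prod(&a), 24.0);
        assert_eq!(CpuBackend::mean(&a), Some(2.5));
        assert_eq!(CpuBackend::mean::<f32>(&[]), None);
    }
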
    #[test]
    fn test_max_min() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::max(&a), Some(4.0));
        assert_eq!(CpuBackend::min(&a), Some(1.0));
    }

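    // Usage sketch: `neg` computes 0 - x, and `abs` flips the sign of negative entries only.
    #[test]
    fn test_neg_abs() {
        let a = [-1.0_f32, 2.0, -3.0];
        let mut b = [0.0_f32; 3];

        CpuBackend::neg(&mut b, &a);
        assert_eq!(b, [1.0, -2.0, 3.0]);

        CpuBackend::abs(&mut b, &a);
        assert_eq!(b, [1.0, 2.0, 3.0]);
    }
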
    #[test]
    fn test_argmax() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmax(&a), Some(1));
    }

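    // Usage sketch: `argmin` mirrors `argmax`, returning the index of the smallest element
    // and `None` for an empty slice.
    #[test]
    fn test_argmin() {
        let a = [3.0_f32, 1.0, 4.0, 1.5];
        assert_eq!(CpuBackend::argmin(&a), Some(1));
        assert_eq!(CpuBackend::argmin::<f32>(&[]), None);
    }
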
    #[test]
    fn test_matmul() {
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]], both row-major 2x2.
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [0.0_f32; 4];

        CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
    }

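    // Usage sketch: with `beta = 1.0`, `sgemm` accumulates into the existing C,
    // so the result here is A * B + C where C starts as all ones. The values are
    // small integers, so the f32 comparison is exact.
    #[test]
    fn test_sgemm_accumulate() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [1.0_f32; 4];

        CpuBackend::sgemm(&mut c, &a, &b, 2, 2, 2, 1.0, 1.0);
        assert_eq!(c, [20.0, 23.0, 44.0, 51.0]);
    }
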
    #[test]
    fn test_transpose() {
        // 2x3 row-major matrix: [[1, 2, 3], [4, 5, 6]].
        let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = [0.0_f32; 6];

        CpuBackend::transpose(&mut b, &a, 2, 3);
        assert_eq!(b, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

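    // Usage sketch: exp(0) = 1, ln(1) = 0, and square roots of perfect squares are exact in f32.
    #[test]
    fn test_exp_ln_sqrt() {
        let mut out = [0.0_f32; 2];

        CpuBackend::exp(&mut out, &[0.0_f32, 0.0]);
        assert_eq!(out, [1.0, 1.0]);

        CpuBackend::ln(&mut out, &[1.0_f32, 1.0]);
        assert_eq!(out, [0.0, 0.0]);

        CpuBackend::sqrt(&mut out, &[4.0_f32, 9.0]);
        assert_eq!(out, [2.0, 3.0]);
    }
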
    #[test]
    fn test_dot() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(CpuBackend::dot(&a, &b), 32.0);
    }

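    // Usage sketch: the comparison kernels write element-wise booleans into a `bool` slice.
    #[test]
    fn test_comparisons() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [3.0_f32, 2.0, 1.0];
        let mut mask = [false; 3];

        CpuBackend::eq(&mut mask, &a, &b);
        assert_eq!(mask, [false, true, false]);

        CpuBackend::lt(&mut mask, &a, &b);
        assert_eq!(mask, [true, false, false]);

        CpuBackend::gt(&mut mask, &a, &b);
        assert_eq!(mask, [false, false, true]);
    }
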
    #[test]
    fn test_fill() {
        let mut a = [0.0_f32; 5];
        CpuBackend::fill(&mut a, 42.0);
        assert_eq!(a, [42.0; 5]);
    }
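
    // Usage sketch, assuming `T::zeroed()` yields numeric zero for floats:
    // `fill_zeros` clears the buffer and `copy` delegates to `copy_from_slice`.
    #[test]
    fn test_fill_zeros_copy() {
        let mut a = [7.0_f32; 4];
        CpuBackend::fill_zeros(&mut a);
        assert_eq!(a, [0.0; 4]);

        let src = [1.0_f32, 2.0, 3.0, 4.0];
        let mut dst = [0.0_f32; 4];
        CpuBackend::copy(&mut dst, &src);
        assert_eq!(dst, src);
    }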
}