1use super::Backend;
18use crate::device::DeviceCapabilities;
19use crate::dtype::{Float, Numeric, Scalar};
20use rayon::prelude::*;
21use sysinfo::System;
22
/// Minimum element count at which elementwise kernels switch from a serial
/// loop to rayon parallel iteration; below this, thread-pool scheduling
/// overhead outweighs the per-element work.
const PARALLEL_THRESHOLD: usize = 4096;

/// Zero-sized CPU compute backend; it carries no state of its own and
/// operates directly on caller-provided buffers.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;
33
34impl CpuBackend {
35 #[must_use]
37 pub const fn new() -> Self {
38 Self
39 }
40}
41
42impl Backend for CpuBackend {
47 fn name(&self) -> &'static str {
48 "cpu"
49 }
50
51 fn is_available(&self) -> bool {
52 true }
54
55 fn capabilities(&self) -> DeviceCapabilities {
56 DeviceCapabilities {
57 name: "CPU".to_string(),
58 total_memory: get_system_memory(),
59 available_memory: get_available_memory(),
60 supports_f16: true,
61 supports_f64: true,
62 max_threads_per_block: num_cpus(),
63 compute_capability: None,
64 }
65 }
66
67 fn allocate(&self, size: usize) -> *mut u8 {
68 if size == 0 {
69 return std::ptr::null_mut();
70 }
71 unsafe {
72 let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
73 std::alloc::alloc(layout)
74 }
75 }
76
77 fn deallocate(&self, ptr: *mut u8, size: usize) {
78 if ptr.is_null() || size == 0 {
79 return;
80 }
81 unsafe {
82 let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
83 std::alloc::dealloc(ptr, layout);
84 }
85 }
86
87 fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
88 unsafe {
90 std::ptr::copy_nonoverlapping(src, dst, size);
91 }
92 }
93
94 fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
95 unsafe {
97 std::ptr::copy_nonoverlapping(src, dst, size);
98 }
99 }
100
101 fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
102 unsafe {
104 std::ptr::copy_nonoverlapping(src, dst, size);
105 }
106 }
107
108 fn synchronize(&self) {
109 }
111}
112
113fn get_system_memory() -> usize {
119 let sys = System::new_all();
120 sys.total_memory() as usize
121}
122
123fn get_available_memory() -> usize {
125 let sys = System::new_all();
126 sys.available_memory() as usize
127}
128
129fn num_cpus() -> usize {
131 std::thread::available_parallelism()
132 .map(std::num::NonZeroUsize::get)
133 .unwrap_or(1)
134}
135
136impl CpuBackend {
141 pub fn add<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
143 debug_assert_eq!(a.len(), b.len());
144 debug_assert_eq!(a.len(), dst.len());
145
146 if dst.len() >= PARALLEL_THRESHOLD {
147 dst.par_iter_mut()
148 .zip(a.par_iter().zip(b.par_iter()))
149 .for_each(|(d, (a_val, b_val))| {
150 *d = *a_val + *b_val;
151 });
152 } else {
153 for i in 0..dst.len() {
154 dst[i] = a[i] + b[i];
155 }
156 }
157 }
158
159 pub fn sub<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
161 debug_assert_eq!(a.len(), b.len());
162 debug_assert_eq!(a.len(), dst.len());
163
164 if dst.len() >= PARALLEL_THRESHOLD {
165 dst.par_iter_mut()
166 .zip(a.par_iter().zip(b.par_iter()))
167 .for_each(|(d, (a_val, b_val))| {
168 *d = *a_val - *b_val;
169 });
170 } else {
171 for i in 0..dst.len() {
172 dst[i] = a[i] - b[i];
173 }
174 }
175 }
176
177 pub fn mul<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
179 debug_assert_eq!(a.len(), b.len());
180 debug_assert_eq!(a.len(), dst.len());
181
182 if dst.len() >= PARALLEL_THRESHOLD {
183 dst.par_iter_mut()
184 .zip(a.par_iter().zip(b.par_iter()))
185 .for_each(|(d, (a_val, b_val))| {
186 *d = *a_val * *b_val;
187 });
188 } else {
189 for i in 0..dst.len() {
190 dst[i] = a[i] * b[i];
191 }
192 }
193 }
194
195 pub fn div<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
197 debug_assert_eq!(a.len(), b.len());
198 debug_assert_eq!(a.len(), dst.len());
199
200 if dst.len() >= PARALLEL_THRESHOLD {
201 dst.par_iter_mut()
202 .zip(a.par_iter().zip(b.par_iter()))
203 .for_each(|(d, (a_val, b_val))| {
204 *d = *a_val / *b_val;
205 });
206 } else {
207 for i in 0..dst.len() {
208 dst[i] = a[i] / b[i];
209 }
210 }
211 }
212
213 pub fn add_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
215 debug_assert_eq!(a.len(), dst.len());
216
217 if dst.len() >= PARALLEL_THRESHOLD {
218 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
219 *d = *a_val + scalar;
220 });
221 } else {
222 for i in 0..dst.len() {
223 dst[i] = a[i] + scalar;
224 }
225 }
226 }
227
228 pub fn mul_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
230 debug_assert_eq!(a.len(), dst.len());
231
232 if dst.len() >= PARALLEL_THRESHOLD {
233 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
234 *d = *a_val * scalar;
235 });
236 } else {
237 for i in 0..dst.len() {
238 dst[i] = a[i] * scalar;
239 }
240 }
241 }
242
243 pub fn neg<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
245 debug_assert_eq!(a.len(), dst.len());
246
247 if dst.len() >= PARALLEL_THRESHOLD {
248 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
249 *d = T::zero() - *a_val;
250 });
251 } else {
252 for i in 0..dst.len() {
253 dst[i] = T::zero() - a[i];
254 }
255 }
256 }
257
258 pub fn abs<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
260 debug_assert_eq!(a.len(), dst.len());
261
262 if dst.len() >= PARALLEL_THRESHOLD {
263 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
264 *d = if *a_val < T::zero() {
265 T::zero() - *a_val
266 } else {
267 *a_val
268 };
269 });
270 } else {
271 for i in 0..dst.len() {
272 dst[i] = if a[i] < T::zero() {
273 T::zero() - a[i]
274 } else {
275 a[i]
276 };
277 }
278 }
279 }
280}
281
282impl CpuBackend {
287 pub fn relu<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
289 debug_assert_eq!(a.len(), dst.len());
290
291 if dst.len() >= PARALLEL_THRESHOLD {
292 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
293 *d = if *a_val > T::zero() {
294 *a_val
295 } else {
296 T::zero()
297 };
298 });
299 } else {
300 for i in 0..dst.len() {
301 dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
302 }
303 }
304 }
305
306 pub fn sigmoid<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
308 debug_assert_eq!(a.len(), dst.len());
309
310 if dst.len() >= PARALLEL_THRESHOLD {
311 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
312 *d = T::one() / (T::one() + (-*a_val).exp_value());
313 });
314 } else {
315 for i in 0..dst.len() {
316 dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
317 }
318 }
319 }
320
321 pub fn tanh<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
323 debug_assert_eq!(a.len(), dst.len());
324
325 if dst.len() >= PARALLEL_THRESHOLD {
326 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
327 *d = a_val.tanh_value();
328 });
329 } else {
330 for i in 0..dst.len() {
331 dst[i] = a[i].tanh_value();
332 }
333 }
334 }
335
336 pub fn exp<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
338 debug_assert_eq!(a.len(), dst.len());
339
340 if dst.len() >= PARALLEL_THRESHOLD {
341 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
342 *d = a_val.exp_value();
343 });
344 } else {
345 for i in 0..dst.len() {
346 dst[i] = a[i].exp_value();
347 }
348 }
349 }
350
351 pub fn ln<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
353 debug_assert_eq!(a.len(), dst.len());
354
355 if dst.len() >= PARALLEL_THRESHOLD {
356 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
357 *d = a_val.ln_value();
358 });
359 } else {
360 for i in 0..dst.len() {
361 dst[i] = a[i].ln_value();
362 }
363 }
364 }
365
366 pub fn sqrt<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
368 debug_assert_eq!(a.len(), dst.len());
369
370 if dst.len() >= PARALLEL_THRESHOLD {
371 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
372 *d = a_val.sqrt_value();
373 });
374 } else {
375 for i in 0..dst.len() {
376 dst[i] = a[i].sqrt_value();
377 }
378 }
379 }
380
381 pub fn square<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
383 debug_assert_eq!(a.len(), dst.len());
384
385 if dst.len() >= PARALLEL_THRESHOLD {
386 dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
387 *d = *a_val * *a_val;
388 });
389 } else {
390 for i in 0..dst.len() {
391 dst[i] = a[i] * a[i];
392 }
393 }
394 }
395}
396
397impl CpuBackend {
402 pub fn sum<T: Numeric>(a: &[T]) -> T {
404 let mut result = T::zero();
405 for &val in a {
406 result = result + val;
407 }
408 result
409 }
410
411 pub fn prod<T: Numeric>(a: &[T]) -> T {
413 let mut result = T::one();
414 for &val in a {
415 result = result * val;
416 }
417 result
418 }
419
420 pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
422 if a.is_empty() {
423 return None;
424 }
425
426 let mut result = a[0];
427 for &val in &a[1..] {
428 if val > result {
429 result = val;
430 }
431 }
432 Some(result)
433 }
434
435 pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
437 if a.is_empty() {
438 return None;
439 }
440
441 let mut result = a[0];
442 for &val in &a[1..] {
443 if val < result {
444 result = val;
445 }
446 }
447 Some(result)
448 }
449
450 pub fn mean<T: Float>(a: &[T]) -> Option<T> {
452 if a.is_empty() {
453 return None;
454 }
455
456 let sum = Self::sum(a);
457 let len = T::from(a.len()).unwrap_or(T::one());
458 Some(sum / len)
459 }
460
461 pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
463 if a.is_empty() {
464 return None;
465 }
466
467 let mut max_idx = 0;
468 let mut max_val = a[0];
469 for (i, &val) in a.iter().enumerate().skip(1) {
470 if val > max_val {
471 max_val = val;
472 max_idx = i;
473 }
474 }
475 Some(max_idx)
476 }
477
478 pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
480 if a.is_empty() {
481 return None;
482 }
483
484 let mut min_idx = 0;
485 let mut min_val = a[0];
486 for (i, &val) in a.iter().enumerate().skip(1) {
487 if val < min_val {
488 min_val = val;
489 min_idx = i;
490 }
491 }
492 Some(min_idx)
493 }
494}
495
impl CpuBackend {
    /// General matrix multiply `c = a * b` for row-major matrices, where
    /// `a` is `m x k`, `b` is `k x n`, and `c` is `m x n`.
    ///
    /// `f32`/`f64` inputs are dispatched to the tuned `matrixmultiply` GEMM
    /// kernels; every other `Numeric` type falls back to a cache-blocked
    /// triple loop. `c` is fully overwritten.
    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        use std::any::TypeId;
        // Fast path: monomorphized for T == f32.
        if TypeId::of::<T>() == TypeId::of::<f32>() {
            unsafe {
                // SAFETY: TypeId equality proves T is exactly f32, so these
                // casts only change the static slice type, not the data.
                let a_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(a) as *const [f32]);
                let b_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(b) as *const [f32]);
                let c_f32: &mut [f32] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f32]);
                Self::matmul_f32(c_f32, a_f32, b_f32, m, n, k);
            }
            return;
        }

        // Fast path: monomorphized for T == f64.
        if TypeId::of::<T>() == TypeId::of::<f64>() {
            unsafe {
                // SAFETY: TypeId equality proves T is exactly f64 (see above).
                let a_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(a) as *const [f64]);
                let b_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(b) as *const [f64]);
                let c_f64: &mut [f64] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f64]);
                Self::matmul_f64(c_f64, a_f64, b_f64, m, n, k);
            }
            return;
        }

        // Tile edge for the blocked fallback; chosen to keep working sets
        // cache-resident (64*64 elements per tile).
        const BLOCK_SIZE: usize = 64;

        // c accumulates, so it must start from zero.
        for val in c.iter_mut() {
            *val = T::zero();
        }

        // Loop order i/p/j with blocking: the innermost j-loop streams
        // contiguous rows of b and c, reusing a single a element.
        for i0 in (0..m).step_by(BLOCK_SIZE) {
            let i_end = (i0 + BLOCK_SIZE).min(m);
            for p0 in (0..k).step_by(BLOCK_SIZE) {
                let p_end = (p0 + BLOCK_SIZE).min(k);
                for j0 in (0..n).step_by(BLOCK_SIZE) {
                    let j_end = (j0 + BLOCK_SIZE).min(n);

                    for i in i0..i_end {
                        for p in p0..p_end {
                            let a_val = a[i * k + p];
                            for j in j0..j_end {
                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
                            }
                        }
                    }
                }
            }
        }
    }

    /// Single-precision GEMM: `c = alpha * a * b + beta * c`, row-major.
    ///
    /// `a` is `m x k`, `b` is `k x n`, `c` is `m x n`; slice lengths are
    /// debug-asserted against the dimensions.
    pub fn sgemm(
        c: &mut [f32],
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        beta: f32,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        // SAFETY: the debug assertions above establish that each pointer is
        // valid for the m*k / k*n / m*n elements the strides describe.
        unsafe {
            // matrixmultiply::sgemm argument order:
            // (m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc).
            // Row-major layout => row stride = row length, column stride = 1.
            matrixmultiply::sgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, b.as_ptr(),
                n as isize,
                1, beta,
                c.as_mut_ptr(),
                n as isize,
                1, );
        }
    }

    /// Double-precision GEMM: `c = alpha * a * b + beta * c`, row-major.
    ///
    /// Same contract and stride convention as [`Self::sgemm`].
    pub fn dgemm(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        beta: f64,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        // SAFETY: lengths checked above; strides describe row-major layout.
        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, b.as_ptr(),
                n as isize,
                1, beta,
                c.as_mut_ptr(),
                n as isize,
                1, );
        }
    }

    /// `c = a * b` for f32 (alpha = 1, beta = 0, i.e. `c` is overwritten).
    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// `c = a * b` for f64 (alpha = 1, beta = 0, i.e. `c` is overwritten).
    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Out-of-place transpose: `dst` becomes the `cols x rows` transpose of
    /// the row-major `rows x cols` matrix `src`.
    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
        debug_assert_eq!(src.len(), rows * cols);
        debug_assert_eq!(dst.len(), rows * cols);

        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
    }

    /// Dot product of two equal-length vectors; `T::zero()` when empty.
    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
        debug_assert_eq!(a.len(), b.len());

        let mut sum = T::zero();
        for i in 0..a.len() {
            sum = sum + a[i] * b[i];
        }
        sum
    }
}
675
676impl CpuBackend {
681 pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
683 debug_assert_eq!(a.len(), b.len());
684 debug_assert_eq!(a.len(), dst.len());
685
686 for i in 0..dst.len() {
687 dst[i] = a[i] == b[i];
688 }
689 }
690
691 pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
693 debug_assert_eq!(a.len(), b.len());
694 debug_assert_eq!(a.len(), dst.len());
695
696 for i in 0..dst.len() {
697 dst[i] = a[i] < b[i];
698 }
699 }
700
701 pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
703 debug_assert_eq!(a.len(), b.len());
704 debug_assert_eq!(a.len(), dst.len());
705
706 for i in 0..dst.len() {
707 dst[i] = a[i] > b[i];
708 }
709 }
710}
711
712impl CpuBackend {
717 pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
719 for elem in dst.iter_mut() {
720 *elem = value;
721 }
722 }
723
724 pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
726 Self::fill(dst, T::zeroed());
727 }
728
729 pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
731 debug_assert_eq!(dst.len(), src.len());
732 dst.copy_from_slice(src);
733 }
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        let mut out = [0.0_f32; 3];
        CpuBackend::add(&mut out, &lhs, &rhs);
        assert_eq!(out, [5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_mul() {
        let lhs = [2.0_f32, 3.0, 4.0];
        let rhs = [2.0_f32, 2.0, 2.0];
        let mut out = [0.0_f32; 3];
        CpuBackend::mul(&mut out, &lhs, &rhs);
        assert_eq!(out, [4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_relu() {
        let input = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut out = [0.0_f32; 4];
        CpuBackend::relu(&mut out, &input);
        // Negatives clamp to zero; non-negatives pass through.
        assert_eq!(out, [0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sum() {
        let input = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::sum(&input), 10.0);
    }

    #[test]
    fn test_max_min() {
        let input = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::max(&input), Some(4.0));
        assert_eq!(CpuBackend::min(&input), Some(1.0));
    }

    #[test]
    fn test_argmax() {
        let input = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmax(&input), Some(1));
    }

    #[test]
    fn test_matmul() {
        // 2x2 * 2x2, row-major.
        let lhs = [1.0_f32, 2.0, 3.0, 4.0];
        let rhs = [5.0_f32, 6.0, 7.0, 8.0];
        let mut out = [0.0_f32; 4];
        CpuBackend::matmul(&mut out, &lhs, &rhs, 2, 2, 2);
        assert_eq!(out, [19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_transpose() {
        // 2x3 row-major input -> 3x2 transpose.
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut out = [0.0_f32; 6];
        CpuBackend::transpose(&mut out, &input, 2, 3);
        assert_eq!(out, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_dot() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        assert_eq!(CpuBackend::dot(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_fill() {
        let mut buf = [0.0_f32; 5];
        CpuBackend::fill(&mut buf, 42.0);
        assert_eq!(buf, [42.0; 5]);
    }
}