use super::Backend;
use crate::device::DeviceCapabilities;
use crate::dtype::{Float, Numeric, Scalar};
use sysinfo::System;

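/// CPU backend: buffers live in host memory and kernels run as plain Rust loops,
/// with the `matrixmultiply` crate used for f32/f64 GEMM.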
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

impl CpuBackend {
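    /// Creates a new CPU backend.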
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Backend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn is_available(&self) -> bool {
        true
    }

    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: "CPU".to_string(),
            total_memory: get_system_memory(),
            available_memory: get_available_memory(),
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: num_cpus(),
            compute_capability: None,
        }
    }

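    /// Allocates `size` bytes aligned to 64 bytes; returns a null pointer when `size` is zero.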
    fn allocate(&self, size: usize) -> *mut u8 {
        if size == 0 {
            return std::ptr::null_mut();
        }
        unsafe {
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::alloc(layout)
        }
    }

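    /// Frees a buffer returned by `allocate`; `size` must match the original allocation.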
    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if ptr.is_null() || size == 0 {
            return;
        }
        unsafe {
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::dealloc(ptr, layout);
        }
    }

    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        // Zero-sized transfers may carry the null pointer from `allocate(0)`; skip them.
        if size == 0 {
            return;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        if size == 0 {
            return;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if size == 0 {
            return;
        }
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn synchronize(&self) {
        // No-op: CPU execution is synchronous, so there is nothing to wait for.
    }
}

fn get_system_memory() -> usize {
    let sys = System::new_all();
    sys.total_memory() as usize
}

fn get_available_memory() -> usize {
    let sys = System::new_all();
    sys.available_memory() as usize
}

fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(std::num::NonZeroUsize::get)
        .unwrap_or(1)
}

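/// Element-wise arithmetic kernels.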
impl CpuBackend {
    pub fn add<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] + b[i];
        }
    }

    pub fn sub<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] - b[i];
        }
    }

    pub fn mul<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] * b[i];
        }
    }

    pub fn div<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] / b[i];
        }
    }

    pub fn add_scalar<T: Numeric>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] + scalar;
        }
    }

    pub fn mul_scalar<T: Numeric>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] * scalar;
        }
    }

    pub fn neg<T: Numeric>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = T::zero() - a[i];
        }
    }

    pub fn abs<T: Numeric>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = if a[i] < T::zero() {
                T::zero() - a[i]
            } else {
                a[i]
            };
        }
    }
}

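/// Unary math kernels: activations, exponentials, logarithms, and powers.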
impl CpuBackend {
    pub fn relu<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
        }
    }

    pub fn sigmoid<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
        }
    }

    pub fn tanh<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i].tanh_value();
        }
    }

    pub fn exp<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i].exp_value();
        }
    }

    pub fn ln<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i].ln_value();
        }
    }

    pub fn sqrt<T: Float>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i].sqrt_value();
        }
    }

    pub fn square<T: Numeric>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] * a[i];
        }
    }
}

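/// Reduction kernels.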
impl CpuBackend {
    pub fn sum<T: Numeric>(a: &[T]) -> T {
        let mut result = T::zero();
        for &val in a {
            result = result + val;
        }
        result
    }

    pub fn prod<T: Numeric>(a: &[T]) -> T {
        let mut result = T::one();
        for &val in a {
            result = result * val;
        }
        result
    }

    pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val > result {
                result = val;
            }
        }
        Some(result)
    }

    pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val < result {
                result = val;
            }
        }
        Some(result)
    }

    pub fn mean<T: Float>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let sum = Self::sum(a);
        let len = T::from(a.len()).unwrap_or(T::one());
        Some(sum / len)
    }

    pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut max_idx = 0;
        let mut max_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }
        Some(max_idx)
    }

    pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut min_idx = 0;
        let mut min_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val < min_val {
                min_val = val;
                min_idx = i;
            }
        }
        Some(min_idx)
    }
}

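/// Linear algebra kernels.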
impl CpuBackend {
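    /// Cache-blocked matrix multiply: computes `c = a * b` for row-major
    /// `a` (`m x k`), `b` (`k x n`), and `c` (`m x n`).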
    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        const BLOCK_SIZE: usize = 64;

        for val in c.iter_mut() {
            *val = T::zero();
        }

        for i0 in (0..m).step_by(BLOCK_SIZE) {
            let i_end = (i0 + BLOCK_SIZE).min(m);
            for p0 in (0..k).step_by(BLOCK_SIZE) {
                let p_end = (p0 + BLOCK_SIZE).min(k);
                for j0 in (0..n).step_by(BLOCK_SIZE) {
                    let j_end = (j0 + BLOCK_SIZE).min(n);

                    for i in i0..i_end {
                        for p in p0..p_end {
                            let a_val = a[i * k + p];
                            for j in j0..j_end {
                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
                            }
                        }
                    }
                }
            }
        }
    }

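    /// Single-precision GEMM via `matrixmultiply`: `c = alpha * a * b + beta * c`,
    /// all matrices row-major.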
    pub fn sgemm(
        c: &mut [f32],
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        beta: f32,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize, // row stride of a
                1,          // column stride of a
                b.as_ptr(),
                n as isize, // row stride of b
                1,          // column stride of b
                beta,
                c.as_mut_ptr(),
                n as isize, // row stride of c
                1,          // column stride of c
            );
        }
    }

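    /// Double-precision GEMM via `matrixmultiply`: `c = alpha * a * b + beta * c`,
    /// all matrices row-major.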
    pub fn dgemm(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        beta: f64,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize, // row stride of a
                1,          // column stride of a
                b.as_ptr(),
                n as isize, // row stride of b
                1,          // column stride of b
                beta,
                c.as_mut_ptr(),
                n as isize, // row stride of c
                1,          // column stride of c
            );
        }
    }

    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

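    /// Transposes a row-major `rows x cols` matrix `src` into `dst` (`cols x rows`).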
    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
        debug_assert_eq!(src.len(), rows * cols);
        debug_assert_eq!(dst.len(), rows * cols);

        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
    }

    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
        debug_assert_eq!(a.len(), b.len());

        let mut sum = T::zero();
        for i in 0..a.len() {
            sum = sum + a[i] * b[i];
        }
        sum
    }
}

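/// Element-wise comparison kernels producing boolean masks.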
impl CpuBackend {
    pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] == b[i];
        }
    }

    pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] < b[i];
        }
    }

    pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] > b[i];
        }
    }
}

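/// Buffer fill and copy helpers.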
impl CpuBackend {
    pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
        for elem in dst.iter_mut() {
            *elem = value;
        }
    }

    pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
        Self::fill(dst, T::zeroed());
    }

    pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
        debug_assert_eq!(dst.len(), src.len());
        dst.copy_from_slice(src);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add(&mut c, &a, &b);
        assert_eq!(c, [5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_mul() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [2.0_f32, 2.0, 2.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::mul(&mut c, &a, &b);
        assert_eq!(c, [4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_relu() {
        let a = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut b = [0.0_f32; 4];

        CpuBackend::relu(&mut b, &a);
        assert_eq!(b, [0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sum() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::sum(&a), 10.0);
    }

    #[test]
    fn test_max_min() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::max(&a), Some(4.0));
        assert_eq!(CpuBackend::min(&a), Some(1.0));
    }

    #[test]
    fn test_argmax() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmax(&a), Some(1));
    }

    #[test]
    fn test_matmul() {
        // [1 2; 3 4] * [5 6; 7 8] = [19 22; 43 50]
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [0.0_f32; 4];

        CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_transpose() {
        // 2x3 row-major input becomes its 3x2 transpose.
        let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = [0.0_f32; 6];

        CpuBackend::transpose(&mut b, &a, 2, 3);
        assert_eq!(b, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_dot() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(CpuBackend::dot(&a, &b), 32.0);
    }

    #[test]
    fn test_fill() {
        let mut a = [0.0_f32; 5];
        CpuBackend::fill(&mut a, 42.0);
        assert_eq!(a, [42.0; 5]);
    }
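
    // Added checks (sketches, not part of the original suite): they exercise the
    // sgemm alpha/beta path and the mean/sigmoid helpers, assuming `Float::from(usize)`
    // and `exp_value` behave as they are used in the implementations above.
    #[test]
    fn test_sgemm_alpha_beta() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [1.0_f32; 4];

        // c = 1.0 * a * b + 1.0 * c, so each entry of the plain matmul result is offset by 1.
        CpuBackend::sgemm(&mut c, &a, &b, 2, 2, 2, 1.0, 1.0);
        assert_eq!(c, [20.0, 23.0, 44.0, 51.0]);
    }

    #[test]
    fn test_mean() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::mean(&a), Some(2.5));
    }

    #[test]
    fn test_sigmoid_zero() {
        let a = [0.0_f32];
        let mut b = [0.0_f32];

        // sigmoid(0) = 1 / (1 + exp(0)) = 0.5
        CpuBackend::sigmoid(&mut b, &a);
        assert!((b[0] - 0.5).abs() < 1e-6);
    }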
}