numrs/backend/cpu/
parallel.rs1use crate::array::Array;
8use anyhow::{bail, Result};
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11
12#[cfg(not(target_arch = "wasm32"))]
16pub fn matmul_with_kernel(
17 a: &Array,
18 b: &Array,
19 kernel: fn(&Array, &Array) -> Array,
20) -> Result<Array> {
21 if a.shape.len() != 2 || b.shape.len() != 2 {
23 bail!("matmul requires 2-D arrays");
24 }
25
26 let m = a.shape[0];
27 let k = a.shape[1];
28 let n = b.shape[1];
29
30 if k != b.shape[0] {
31 bail!("inner dimension mismatch: {} != {}", k, b.shape[0]);
32 }
33
34 if m < 500 {
36 return Ok(kernel(a, b));
37 }
38
39 let num_threads = rayon::current_num_threads();
41
42 let block_size = if m >= 2048 {
45 256.max((m + num_threads * 4 - 1) / (num_threads * 4))
47 } else {
48 64.max((m + num_threads - 1) / num_threads)
50 };
51
52 let mut result = vec![0.0f32; m * n];
54
55 result
57 .par_chunks_mut(block_size * n)
58 .enumerate()
59 .for_each(|(block_idx, out_block)| {
60 let start = block_idx * block_size;
61 let end = (start + block_size).min(m);
62 let block_rows = end - start;
63
64 let a_block_data: Vec<f32> = (start..end)
66 .flat_map(|i| &a.data[i * k..(i + 1) * k])
67 .copied()
68 .collect();
69
70 let a_block = Array::new(vec![block_rows, k], a_block_data);
71
72 let result_block = kernel(&a_block, b);
74
75 out_block[..result_block.data.len()].copy_from_slice(&result_block.data);
77 });
78
79 Ok(Array::new(vec![m, n], result))
80}
81
/// Multiplies two 2-D arrays on `wasm32` by calling `kernel` once on the
/// whole operands; no row splitting is done on this target.
///
/// The operands are validated with the same checks and messages as the
/// non-wasm variant, so callers see the same errors on every target.
///
/// # Errors
///
/// Returns an error when either operand is not 2-D or when the inner
/// dimensions differ.
#[cfg(target_arch = "wasm32")]
pub fn matmul_with_kernel(
    a: &Array,
    b: &Array,
    kernel: fn(&Array, &Array) -> Array,
) -> Result<Array> {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        bail!("matmul requires 2-D arrays");
    }

    let k = a.shape[1];
    if k != b.shape[0] {
        bail!("inner dimension mismatch: {} != {}", k, b.shape[0]);
    }

    Ok(kernel(a, b))
}
93
/// Matrix product entry point used when the `numrs_has_blas` cfg is set.
///
/// Forwards both operands to `matmul_with_kernel` with
/// `crate::backend::blas::matmul_blas` as the per-block kernel.
#[cfg(numrs_has_blas)]
pub fn matmul_parallel(a: &Array, b: &Array) -> Result<Array> {
    let kernel: fn(&Array, &Array) -> Array = crate::backend::blas::matmul_blas;
    matmul_with_kernel(a, b, kernel)
}
99
100#[cfg(not(numrs_has_blas))]
102pub fn matmul_parallel(a: &Array, b: &Array) -> Result<Array> {
103 matmul_with_kernel(a, b, super::matmul_simd_direct)
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplying two all-ones 1000x1000 matrices must give 1000.0 in every
    /// cell. All cells are checked so a row block that was never written
    /// cannot go unnoticed.
    #[test]
    fn test_parallel_matmul() {
        let a = Array::new(vec![1000, 1000], vec![1.0; 1000 * 1000]);
        let b = Array::new(vec![1000, 1000], vec![1.0; 1000 * 1000]);

        let result = matmul_parallel(&a, &b).unwrap();

        assert_eq!(result.shape, vec![1000, 1000]);
        assert_eq!(result.data.len(), 1000 * 1000);
        assert!(result.data.iter().all(|&v| (v - 1000.0).abs() < 0.1));
    }

    /// Uses a row count that usually does not divide evenly into row blocks,
    /// so the shorter final block is also covered.
    #[test]
    fn test_parallel_matmul_uneven_rows() {
        let a = Array::new(vec![777, 3], vec![1.0; 777 * 3]);
        let b = Array::new(vec![3, 5], vec![1.0; 3 * 5]);

        let result = matmul_parallel(&a, &b).unwrap();

        assert_eq!(result.shape, vec![777, 5]);
        assert_eq!(result.data.len(), 777 * 5);
        assert!(result.data.iter().all(|&v| (v - 3.0).abs() < 1e-4));
    }
}