numrs/backend/cpu/
parallel.rs

1//! Multi-threaded matmul usando Rayon
2//!
3//! Strategy: Dividir la matriz en bloques y procesar en paralelo con Rayon.
4//! IMPORTANTE: Esta función NO decide el backend, solo paraleliza.
5//! El kernel específico se pasa como parámetro.
6
7use crate::array::Array;
8use anyhow::{bail, Result};
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11
12/// Matmul usando un kernel específico (BLAS, SIMD, o Scalar)
13/// Esta función NO decide qué backend usar - solo ejecuta el kernel dado
14// Native implementation using Rayon
15#[cfg(not(target_arch = "wasm32"))]
16pub fn matmul_with_kernel(
17    a: &Array,
18    b: &Array,
19    kernel: fn(&Array, &Array) -> Array,
20) -> Result<Array> {
21    // Validación
22    if a.shape.len() != 2 || b.shape.len() != 2 {
23        bail!("matmul requires 2-D arrays");
24    }
25
26    let m = a.shape[0];
27    let k = a.shape[1];
28    let n = b.shape[1];
29
30    if k != b.shape[0] {
31        bail!("inner dimension mismatch: {} != {}", k, b.shape[0]);
32    }
33
34    // Para matrices pequeñas, ejecutar directamente sin overhead de threading
35    if m < 500 {
36        return Ok(kernel(a, b));
37    }
38
39    // Determinar tamaño de bloques adaptativo
40    let num_threads = rayon::current_num_threads();
41
42    // Para matrices muy grandes (>2048), limitar el tamaño de bloque para evitar overhead de memoria
43    // Para matrices medianas, usar bloques más grandes para mejor aprovechamiento de cache
44    let block_size = if m >= 2048 {
45        // Matrices muy grandes: bloques más pequeños, más bloques
46        256.max((m + num_threads * 4 - 1) / (num_threads * 4))
47    } else {
48        // Matrices medianas: bloques balanceados
49        64.max((m + num_threads - 1) / num_threads)
50    };
51
52    // Pre-alocar el resultado completo para evitar concatenaciones costosas
53    let mut result = vec![0.0f32; m * n];
54
55    // Procesar bloques en paralelo escribiendo directamente al resultado
56    result
57        .par_chunks_mut(block_size * n)
58        .enumerate()
59        .for_each(|(block_idx, out_block)| {
60            let start = block_idx * block_size;
61            let end = (start + block_size).min(m);
62            let block_rows = end - start;
63
64            // Extraer sub-matriz de A (sin copiar toda la fila, solo lo necesario)
65            let a_block_data: Vec<f32> = (start..end)
66                .flat_map(|i| &a.data[i * k..(i + 1) * k])
67                .copied()
68                .collect();
69
70            let a_block = Array::new(vec![block_rows, k], a_block_data);
71
72            // Ejecutar matmul del bloque usando el kernel dado
73            let result_block = kernel(&a_block, b);
74
75            // Copiar resultado directamente al slice de salida
76            out_block[..result_block.data.len()].copy_from_slice(&result_block.data);
77        });
78
79    Ok(Array::new(vec![m, n], result))
80}
81
82// WASM implementation (Serial) - Rayon panics on WASM so we force serial execution
83#[cfg(target_arch = "wasm32")]
84pub fn matmul_with_kernel(
85    a: &Array,
86    b: &Array,
87    kernel: fn(&Array, &Array) -> Array,
88) -> Result<Array> {
89    // En WASM simplemente llamamos al kernel directamente (ya sea BLAS simulado o SIMD)
90    // No vale la pena el overhead de blocking/splitting si es single-thread
91    Ok(kernel(a, b))
92}
93
94/// Función legacy para compatibilidad - usa BLAS si está disponible
95#[cfg(numrs_has_blas)]
96pub fn matmul_parallel(a: &Array, b: &Array) -> Result<Array> {
97    matmul_with_kernel(a, b, crate::backend::blas::matmul_blas)
98}
99
100/// Función legacy para compatibilidad - usa SIMD si BLAS no está disponible
101#[cfg(not(numrs_has_blas))]
102pub fn matmul_parallel(a: &Array, b: &Array) -> Result<Array> {
103    matmul_with_kernel(a, b, super::matmul_simd_direct)
104}
105
106#[cfg(test)]
107mod tests {
108    use super::*;
109
110    #[test]
111    fn test_parallel_matmul() {
112        let a = Array::new(vec![1000, 1000], vec![1.0; 1000 * 1000]);
113        let b = Array::new(vec![1000, 1000], vec![1.0; 1000 * 1000]);
114
115        let result = matmul_parallel(&a, &b).unwrap();
116
117        assert_eq!(result.shape, vec![1000, 1000]);
118        // Cada elemento debería ser suma de 1000 unos = 1000.0
119        assert!((result.data[0] - 1000.0).abs() < 0.1);
120    }
121}