Skip to main content

scirs2/
parallel.rs

//! Parallel batch processing utilities.
//!
//! Provides Python-facing functions that run Rayon-based parallel computations
//! while the Python GIL is released, so other Python threads can keep running
//! during the computation.
//!
//! # Example (Python)
//! ```python
//! import scirs2
//!
//! # Compute means of multiple float arrays in parallel.
//! arrays = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0, 7.0, 8.0, 9.0]]
//! means = scirs2.parallel_map_mean(arrays)
//! # means == [2.0, 4.5, 7.5]
//!
//! # Parallel batch matrix-vector multiply (2x2 row-major matrices).
//! mats = [[1.0, 0.0, 0.0, 1.0], [2.0, 0.0, 0.0, 2.0]]
//! vecs = [[3.0, 4.0], [3.0, 4.0]]
//! results = scirs2.parallel_batch_matvec(mats, vecs, 2, 2)
//! # results == [[3.0, 4.0], [6.0, 8.0]]
//! ```
21
22use pyo3::prelude::*;
23
24// ──────────────────────────────────────────────────────────────────────────────
25// parallel_map_mean
26// ──────────────────────────────────────────────────────────────────────────────
27
28/// Compute the arithmetic mean of each sub-array in parallel.
29///
30/// The GIL is released during Rayon computation, allowing Python threads to run
31/// concurrently.
32///
33/// Empty sub-arrays return `0.0`.
34///
35/// # Arguments
36/// * `arrays` – list of float arrays, one mean value is produced per element.
37///
38/// # Returns
39/// A list of `f64` mean values with the same length as `arrays`.
40#[pyfunction]
41pub fn parallel_map_mean(py: Python<'_>, arrays: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
42    let result = py.detach(|| {
43        use rayon::prelude::*;
44        arrays
45            .par_iter()
46            .map(|arr| {
47                if arr.is_empty() {
48                    0.0
49                } else {
50                    arr.iter().sum::<f64>() / arr.len() as f64
51                }
52            })
53            .collect::<Vec<f64>>()
54    });
55    Ok(result)
56}
57
58// ──────────────────────────────────────────────────────────────────────────────
59// parallel_batch_matvec
60// ──────────────────────────────────────────────────────────────────────────────
61
62/// Perform a batch of matrix-vector multiplications in parallel.
63///
64/// Each pair `(matrices[i], vectors[i])` represents a multiplication
65/// `A * v` where `A` is stored in row-major order with shape
66/// `(n_rows, n_cols)` and `v` is a vector of length `n_cols`.
67///
68/// The GIL is released during Rayon computation.
69///
70/// # Arguments
71/// * `matrices` – list of flat row-major matrices, each with `n_rows * n_cols` elements.
72/// * `vectors`  – list of vectors, each with `n_cols` elements.
73/// * `n_rows`   – number of rows in each matrix.
74/// * `n_cols`   – number of columns in each matrix (= length of each vector).
75///
76/// # Errors
77/// Returns `ValueError` if:
78/// - `matrices` and `vectors` have different lengths.
79/// - Any matrix does not have exactly `n_rows * n_cols` elements.
80/// - Any vector does not have exactly `n_cols` elements.
81#[pyfunction]
82pub fn parallel_batch_matvec(
83    py: Python<'_>,
84    matrices: Vec<Vec<f64>>,
85    vectors: Vec<Vec<f64>>,
86    n_rows: usize,
87    n_cols: usize,
88) -> PyResult<Vec<Vec<f64>>> {
89    if matrices.len() != vectors.len() {
90        return Err(pyo3::exceptions::PyValueError::new_err(format!(
91            "matrices ({}) and vectors ({}) must have equal length",
92            matrices.len(),
93            vectors.len()
94        )));
95    }
96
97    let expected_mat = n_rows * n_cols;
98    for (i, mat) in matrices.iter().enumerate() {
99        if mat.len() != expected_mat {
100            return Err(pyo3::exceptions::PyValueError::new_err(format!(
101                "matrix[{i}] has {} elements but expected {expected_mat} (n_rows={n_rows}, n_cols={n_cols})",
102                mat.len()
103            )));
104        }
105    }
106    for (i, vec) in vectors.iter().enumerate() {
107        if vec.len() != n_cols {
108            return Err(pyo3::exceptions::PyValueError::new_err(format!(
109                "vector[{i}] has {} elements but expected {n_cols} (n_cols)",
110                vec.len()
111            )));
112        }
113    }
114
115    let result = py.detach(|| {
116        use rayon::prelude::*;
117        matrices
118            .par_iter()
119            .zip(vectors.par_iter())
120            .map(|(mat, vec)| {
121                (0..n_rows)
122                    .map(|r| {
123                        (0..n_cols)
124                            .map(|c| mat[r * n_cols + c] * vec[c])
125                            .sum::<f64>()
126                    })
127                    .collect::<Vec<f64>>()
128            })
129            .collect::<Vec<Vec<f64>>>()
130    });
131    Ok(result)
132}
133
134// ──────────────────────────────────────────────────────────────────────────────
135// Module registration
136// ──────────────────────────────────────────────────────────────────────────────
137
138/// Register parallel batch-processing functions into a PyO3 module.
139///
140/// Exposes [`parallel_map_mean`] and [`parallel_batch_matvec`].
141pub fn register_parallel_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
142    m.add_function(wrap_pyfunction!(parallel_map_mean, m)?)?;
143    m.add_function(wrap_pyfunction!(parallel_batch_matvec, m)?)?;
144    Ok(())
145}
146
147// ──────────────────────────────────────────────────────────────────────────────
148// Tests
149// ──────────────────────────────────────────────────────────────────────────────
150
#[cfg(test)]
mod tests {
    use super::*;

    // Each test embeds an interpreter and calls the Rust entry points directly
    // with an attached `Python` token.

    #[test]
    fn parallel_map_mean_empty_arrays() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let arrays: Vec<Vec<f64>> = vec![vec![], vec![1.0, 2.0, 3.0], vec![]];
            let means = parallel_map_mean(py, arrays).expect("parallel_map_mean failed");
            assert_eq!(means.len(), 3);
            assert!((means[0]).abs() < f64::EPSILON); // empty → 0.0
            assert!((means[1] - 2.0).abs() < f64::EPSILON);
            assert!((means[2]).abs() < f64::EPSILON); // empty → 0.0
        });
    }

    #[test]
    fn parallel_map_mean_empty_batch() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let means = parallel_map_mean(py, Vec::new()).expect("parallel_map_mean failed");
            assert!(means.is_empty());
        });
    }

    #[test]
    fn parallel_batch_matvec_identity_matrix() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // I × [3, 4] = [3, 4]  and  (2·I) × [3, 4] = [6, 8].
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0], vec![2.0, 0.0, 0.0, 2.0]];
            let vectors = vec![vec![3.0, 4.0], vec![3.0, 4.0]];
            let results =
                parallel_batch_matvec(py, matrices, vectors, 2, 2).expect("matvec failed");
            assert_eq!(results.len(), 2);
            assert!((results[0][0] - 3.0).abs() < f64::EPSILON);
            assert!((results[0][1] - 4.0).abs() < f64::EPSILON);
            assert!((results[1][0] - 6.0).abs() < f64::EPSILON);
            assert!((results[1][1] - 8.0).abs() < f64::EPSILON);
        });
    }

    #[test]
    fn parallel_batch_matvec_non_square_is_row_major() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // [[1, 2, 3], [4, 5, 6]] × [1, 10, 100] = [321, 654]; a column-major
            // interpretation would produce a different answer.
            let matrices = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]];
            let vectors = vec![vec![1.0, 10.0, 100.0]];
            let results =
                parallel_batch_matvec(py, matrices, vectors, 2, 3).expect("matvec failed");
            assert_eq!(results, vec![vec![321.0, 654.0]]);
        });
    }

    #[test]
    fn parallel_batch_matvec_mismatched_lengths_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0]];
            let vectors = vec![vec![1.0], vec![2.0]]; // 2 vectors, 1 matrix → error
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2)
                .expect_err("batch length mismatch must be rejected");
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }

    #[test]
    fn parallel_batch_matvec_wrong_matrix_size_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let matrices = vec![vec![1.0, 2.0, 3.0]]; // 3 elements, 2×2 needs 4
            let vectors = vec![vec![1.0, 2.0]];
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2)
                .expect_err("undersized matrix must be rejected");
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }

    #[test]
    fn parallel_batch_matvec_wrong_vector_size_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0]];
            let vectors = vec![vec![1.0]]; // 1 element, n_cols = 2
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2)
                .expect_err("undersized vector must be rejected");
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }
}