use pyo3::prelude::*;
/// Computes the arithmetic mean of every inner array, in parallel.
///
/// The work runs inside `py.detach`, so rayon's worker threads process the
/// arrays while this thread is detached from the Python interpreter. The
/// output has one entry per input array, in the same order.
///
/// An empty inner array yields `0.0` (not `NaN`).
#[pyfunction]
pub fn parallel_map_mean(py: Python<'_>, arrays: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
    let means: Vec<f64> = py.detach(|| {
        use rayon::prelude::*;
        arrays
            .par_iter()
            .map(|values| match values.len() {
                // Avoid 0.0 / 0.0 == NaN for empty inputs.
                0 => 0.0,
                count => values.iter().sum::<f64>() / count as f64,
            })
            .collect()
    });
    Ok(means)
}
/// Multiplies each matrix in a batch by its paired vector, in parallel.
///
/// Each matrix is a flat, row-major buffer of `n_rows * n_cols` elements:
/// element `(r, c)` lives at index `r * n_cols + c`. Each vector has `n_cols`
/// elements. The result holds one `n_rows`-long product per
/// (matrix, vector) pair, in input order.
///
/// The multiplication runs inside `py.detach` using rayon.
///
/// # Errors
///
/// Returns `ValueError` when:
/// - `matrices` and `vectors` have different lengths,
/// - `n_rows * n_cols` overflows `usize`, or
/// - any matrix or vector has the wrong number of elements.
#[pyfunction]
pub fn parallel_batch_matvec(
    py: Python<'_>,
    matrices: Vec<Vec<f64>>,
    vectors: Vec<Vec<f64>>,
    n_rows: usize,
    n_cols: usize,
) -> PyResult<Vec<Vec<f64>>> {
    if matrices.len() != vectors.len() {
        return Err(pyo3::exceptions::PyValueError::new_err(format!(
            "matrices ({}) and vectors ({}) must have equal length",
            matrices.len(),
            vectors.len()
        )));
    }
    // The shape comes from Python, so the product must be checked: an
    // unchecked multiply panics in debug builds (surfacing as PanicException)
    // and silently wraps in release builds, defeating the size checks below.
    let expected_mat = n_rows.checked_mul(n_cols).ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "n_rows ({n_rows}) * n_cols ({n_cols}) overflows usize"
        ))
    })?;
    for (i, mat) in matrices.iter().enumerate() {
        if mat.len() != expected_mat {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "matrix[{i}] has {} elements but expected {expected_mat} (n_rows={n_rows}, n_cols={n_cols})",
                mat.len()
            )));
        }
    }
    for (i, vector) in vectors.iter().enumerate() {
        if vector.len() != n_cols {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "vector[{i}] has {} elements but expected {n_cols} (n_cols)",
                vector.len()
            )));
        }
    }
    let result = py.detach(|| {
        use rayon::prelude::*;
        matrices
            .par_iter()
            .zip(vectors.par_iter())
            .map(|(mat, vector)| {
                (0..n_rows)
                    .map(|r| {
                        // In bounds: validation guarantees
                        // mat.len() == n_rows * n_cols, and (r + 1) * n_cols
                        // cannot overflow because that product was checked.
                        // Works for n_cols == 0 too, where each row is empty.
                        let row = &mat[r * n_cols..(r + 1) * n_cols];
                        row.iter().zip(vector).map(|(a, x)| a * x).sum::<f64>()
                    })
                    .collect::<Vec<f64>>()
            })
            .collect::<Vec<Vec<f64>>>()
    });
    Ok(result)
}
/// Registers the parallel helpers on module `m`.
///
/// Adds `parallel_map_mean` first, then `parallel_batch_matvec`. Returns the
/// first error from wrapping or adding a function.
pub fn register_parallel_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let mean_fn = wrap_pyfunction!(parallel_map_mean, m)?;
    m.add_function(mean_fn)?;
    let matvec_fn = wrap_pyfunction!(parallel_batch_matvec, m)?;
    m.add_function(matvec_fn)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parallel_map_mean_empty_arrays() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let arrays: Vec<Vec<f64>> = vec![vec![], vec![1.0, 2.0, 3.0], vec![]];
            let means = parallel_map_mean(py, arrays).expect("parallel_map_mean failed");
            assert_eq!(means.len(), 3);
            assert!(means[0].abs() < f64::EPSILON);
            assert!((means[1] - 2.0).abs() < f64::EPSILON);
            assert!(means[2].abs() < f64::EPSILON);
        });
    }

    #[test]
    fn parallel_batch_matvec_identity_matrix() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0], vec![2.0, 0.0, 0.0, 2.0]];
            let vectors = vec![vec![3.0, 4.0], vec![3.0, 4.0]];
            let results =
                parallel_batch_matvec(py, matrices, vectors, 2, 2).expect("matvec failed");
            assert_eq!(results.len(), 2);
            assert!((results[0][0] - 3.0).abs() < f64::EPSILON);
            assert!((results[0][1] - 4.0).abs() < f64::EPSILON);
            assert!((results[1][0] - 6.0).abs() < f64::EPSILON);
            assert!((results[1][1] - 8.0).abs() < f64::EPSILON);
        });
    }

    #[test]
    fn parallel_batch_matvec_non_square_matrix() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // Row-major 2x3 matrix [[1, 2, 3], [4, 5, 6]] times [1, 0, 2].
            let matrices = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]];
            let vectors = vec![vec![1.0, 0.0, 2.0]];
            let results =
                parallel_batch_matvec(py, matrices, vectors, 2, 3).expect("matvec failed");
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].len(), 2);
            assert!((results[0][0] - 7.0).abs() < f64::EPSILON);
            assert!((results[0][1] - 16.0).abs() < f64::EPSILON);
        });
    }

    #[test]
    fn parallel_batch_matvec_empty_batch() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let results =
                parallel_batch_matvec(py, Vec::new(), Vec::new(), 2, 2).expect("matvec failed");
            assert!(results.is_empty());
        });
    }

    #[test]
    fn parallel_batch_matvec_mismatched_lengths_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0]];
            let vectors = vec![vec![1.0], vec![2.0]];
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2).unwrap_err();
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }

    #[test]
    fn parallel_batch_matvec_wrong_matrix_size_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // Three elements where a 2x2 matrix needs four.
            let matrices = vec![vec![1.0, 0.0, 0.0]];
            let vectors = vec![vec![1.0, 2.0]];
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2).unwrap_err();
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }

    #[test]
    fn parallel_batch_matvec_wrong_vector_size_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // Vector has three elements but n_cols is 2.
            let matrices = vec![vec![1.0, 0.0, 0.0, 1.0]];
            let vectors = vec![vec![1.0, 2.0, 3.0]];
            let err = parallel_batch_matvec(py, matrices, vectors, 2, 2).unwrap_err();
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }
}