use parking_lot;
use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
use scirs2_core::parallel_ops::*;
use std::sync::Arc;
use crate::error::Result;
/// Unsized closure type for a two-array metric that may be evaluated from
/// several threads at once.
///
/// The closure takes two arrays (in [`compute_metrics_batch`] these are
/// `y_true` and `y_pred`) and returns a scalar score or an error. The
/// `Send + Sync` bounds are what allow a slice of `Box<ParallelMetricFn<..>>`
/// to be iterated with `par_iter`.
pub type ParallelMetricFn<S1, S2, D1, D2> =
    dyn Fn(&ArrayBase<S1, D1>, &ArrayBase<S2, D2>) -> Result<f64> + Send + Sync;
/// Settings for the parallel metric helpers in this module.
///
/// Start from [`ParallelConfig::new`] (same as `Default`) and adjust it with
/// the chainable `with_*` methods, each of which replaces exactly one field.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Chunk-size threshold; `1000` by default.
    ///
    /// NOTE(review): no function visible in this file reads this field (they
    /// take an explicit `chunk_size`) — confirm where it is consumed.
    pub min_chunk_size: usize,
    /// Whether the helpers may take their parallel code path; `true` by default.
    pub parallel_enabled: bool,
    /// Optional thread-count hint; `None` by default.
    ///
    /// NOTE(review): not read by any function visible in this file.
    pub num_threads: Option<usize>,
}

impl Default for ParallelConfig {
    /// Parallelism enabled, `min_chunk_size = 1000`, no thread-count hint.
    fn default() -> Self {
        Self {
            min_chunk_size: 1000,
            parallel_enabled: true,
            num_threads: None,
        }
    }
}

impl ParallelConfig {
    /// Same as [`ParallelConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the config with `min_chunk_size` set to `size`.
    pub fn with_min_chunk_size(self, size: usize) -> Self {
        Self {
            min_chunk_size: size,
            ..self
        }
    }

    /// Returns the config with `parallel_enabled` set to `enabled`.
    pub fn with_parallel_enabled(self, enabled: bool) -> Self {
        Self {
            parallel_enabled: enabled,
            ..self
        }
    }

    /// Returns the config with `num_threads` set to `threads`.
    pub fn with_num_threads(self, threads: Option<usize>) -> Self {
        Self {
            num_threads: threads,
            ..self
        }
    }
}
/// A metric computed from a single array, with a [`ParallelConfig`] available
/// to guide how the computation is carried out.
///
/// NOTE(review): no implementors of this trait are visible in this file —
/// confirm it is used elsewhere.
pub trait ParallelMetric<T, D>
where
    T: Send + Sync,
    D: Dimension,
{
    /// Compute the metric for `x` under `config`.
    ///
    /// # Errors
    /// Implementation-defined; errors are reported through the crate's
    /// `Result` type.
    fn compute_parallel(
        &self,
        x: &ArrayBase<impl Data<Elem = T>, D>,
        config: &ParallelConfig,
    ) -> Result<f64>;
}
/// Evaluate every metric in `metric_fns` on the same `(y_true, y_pred)` pair.
///
/// The returned vector holds one score per metric, in the same order as
/// `metric_fns`.
///
/// The metrics are dispatched through `par_iter` only when
/// `config.parallel_enabled` is set and there are at least two of them.
/// Otherwise they run one after another on the calling thread.
///
/// # Errors
/// Returns an error if any metric fails. On the sequential path, evaluation
/// stops at the first failing metric.
// NOTE(review): presumably silences dead-code warnings while this helper has
// no in-crate callers — confirm.
#[allow(dead_code)]
pub fn compute_metrics_batch<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    metric_fns: &[Box<ParallelMetricFn<S1, S2, D1, D2>>],
    config: &ParallelConfig,
) -> Result<Vec<f64>>
where
    T: Clone + Send + Sync,
    S1: Data<Elem = T> + Sync,
    S2: Data<Elem = T> + Sync,
    D1: Dimension + Sync,
    D2: Dimension + Sync,
{
    // A single metric gains nothing from being farmed out to the thread pool.
    let go_parallel = config.parallel_enabled && metric_fns.len() >= 2;

    if go_parallel {
        metric_fns
            .par_iter()
            .map(|metric| metric(y_true, y_pred))
            .collect()
    } else {
        // `collect` into `Result` short-circuits on the first `Err`.
        metric_fns
            .iter()
            .map(|metric| metric(y_true, y_pred))
            .collect()
    }
}
/// Split `data` into `chunk_size`-element slices, apply `chunk_op` to each
/// slice via `par_iter`, and combine the per-chunk results with `reducer`.
///
/// If `data` fits in one chunk (`data.len() <= chunk_size`), `chunk_op` is
/// applied to the whole slice and its result is returned directly; `reducer`
/// is not called. A `chunk_size` of `0` takes the same single-chunk path,
/// because `slice::chunks(0)` would panic.
///
/// `reducer` receives one result per chunk, presumably in chunk order (rayon's
/// `collect` into `Vec` preserves order — confirm for the `parallel_ops`
/// backend in use).
///
/// # Errors
/// Returns an error if any `chunk_op` call fails, or the error returned by
/// `reducer`.
// NOTE(review): presumably silences dead-code warnings while this helper has
// no in-crate callers — confirm.
#[allow(dead_code)]
pub fn chunked_parallel_compute<T, R>(
    data: &[T],
    chunk_size: usize,
    chunk_op: impl Fn(&[T]) -> Result<R> + Send + Sync,
    reducer: impl Fn(Vec<R>) -> Result<R>,
) -> Result<R>
where
    T: Clone + Send + Sync,
    R: Send + Sync,
{
    // `slice::chunks` panics on a zero size, so treat 0 as "one chunk".
    if chunk_size == 0 || data.len() <= chunk_size {
        return chunk_op(data);
    }
    let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
    let partials = chunks
        .par_iter()
        .map(|chunk| chunk_op(chunk))
        .collect::<Result<Vec<R>>>()?;
    reducer(partials)
}
/// A metric accumulated incrementally over slices of its input.
///
/// A driver such as `compute_chunked_metric` creates a state with
/// `init_state`, feeds it chunks through `process_chunk`, and turns it into a
/// score with `finalize`. There is no method for merging two states, so on
/// that driver's parallel path all chunks are folded into one shared state.
/// Chunks may arrive in any order.
pub trait ChunkedMetric<T> {
    /// Accumulator carried across chunks. `Send + Sync` lets it live behind a
    /// lock shared by worker threads.
    type State: Send + Sync;
    /// Create an empty accumulator.
    fn init_state(&self) -> Self::State;
    /// Fold `chunk` into `state`.
    fn process_chunk(&self, state: &mut Self::State, chunk: &[T]) -> Result<()>;
    /// Convert the accumulated `state` into the final metric value.
    fn finalize(&self, state: &Self::State) -> Result<f64>;
}
/// Evaluate a [`ChunkedMetric`] over `data`, feeding it `chunk_size`-element
/// slices, and return `metric.finalize` of the accumulated state.
///
/// Sequential path: the whole slice is processed as one chunk on the calling
/// thread. This happens when any of the following holds:
/// * `config.parallel_enabled` is `false`;
/// * `data.len() <= chunk_size`;
/// * `chunk_size == 0`, because `slice::chunks(0)` would panic.
///
/// Parallel path: chunks are dispatched via `par_iter`. `ChunkedMetric` cannot
/// merge partial states, so each chunk is folded exactly once into a single
/// shared state while holding its lock. The folds are therefore serialized,
/// and the order in which chunks reach the state is not fixed. Metrics whose
/// result depends on that order (e.g. floating-point sums) may differ in the
/// last bits between runs.
///
/// # Errors
/// Returns an error if any `process_chunk` call fails, or the error returned
/// by `finalize`.
// NOTE(review): presumably silences dead-code warnings while this helper has
// no in-crate callers — confirm.
#[allow(dead_code)]
pub fn compute_chunked_metric<T, M>(
    data: &[T],
    metric: &M,
    chunk_size: usize,
    config: &ParallelConfig,
) -> Result<f64>
where
    T: Clone + Send + Sync,
    M: ChunkedMetric<T> + Send + Sync,
{
    if chunk_size == 0 || data.len() <= chunk_size || !config.parallel_enabled {
        let mut state = metric.init_state();
        metric.process_chunk(&mut state, data)?;
        return metric.finalize(&state);
    }

    let state = Arc::new(parking_lot::Mutex::new(metric.init_state()));
    let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
    let outcome: Result<()> = chunks.par_iter().try_for_each(|chunk| {
        // Previously each chunk was also processed into a throw-away local
        // state, doubling the work. Now every chunk is applied exactly once.
        let mut shared = state.lock();
        metric.process_chunk(&mut *shared, chunk)?;
        Ok(())
    });
    outcome?;

    let final_state = state.lock();
    metric.finalize(&*final_state)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::MetricsError;
    use scirs2_core::ndarray::Array1;

    /// Each builder method should set exactly the field it names.
    #[test]
    fn test_parallel_config() {
        let cfg = ParallelConfig::new()
            .with_min_chunk_size(500)
            .with_parallel_enabled(true)
            .with_num_threads(Some(4));

        assert_eq!(cfg.min_chunk_size, 500);
        assert!(cfg.parallel_enabled);
        assert_eq!(cfg.num_threads, Some(4));
    }

    /// Sequential and parallel evaluation must yield the same ordered scores.
    #[test]
    fn test_compute_metrics_batch() {
        type MetricFn = Box<dyn Fn(&Array1<i32>, &Array1<i32>) -> Result<f64> + Send + Sync>;

        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        let metric_fns: Vec<MetricFn> = vec![
            // Accuracy: fraction of positions where prediction equals truth.
            Box::new(|a, b| {
                if a.len() != b.len() {
                    return Err(MetricsError::InvalidInput("Lengths must match".to_string()));
                }
                let correct = a.iter().zip(b.iter()).filter(|&(a, b)| a == b).count();
                Ok(correct as f64 / a.len() as f64)
            }),
            // Sample count.
            Box::new(|a, _b| Ok(a.len() as f64)),
        ];

        for parallel in [false, true] {
            let config = ParallelConfig::new().with_parallel_enabled(parallel);
            let results = compute_metrics_batch(&y_true, &y_pred, &metric_fns, &config)
                .expect("Operation failed");

            assert_eq!(results.len(), 2);
            assert!((results[0] - 0.5).abs() < 1e-10);
            assert!((results[1] - 6.0).abs() < 1e-10);
        }
    }

    /// Chunked sum of squares must match the straightforward sum.
    #[test]
    fn test_chunked_parallel_compute() {
        let data: Vec<f64> = (0..1000).map(|x| x as f64).collect();

        let sum_of_squares =
            |chunk: &[f64]| -> Result<f64> { Ok(chunk.iter().map(|v| v * v).sum()) };
        let add_partials = |partials: Vec<f64>| -> Result<f64> { Ok(partials.iter().sum()) };

        let result = chunked_parallel_compute(&data, 100, sum_of_squares, add_partials)
            .expect("Operation failed");

        let expected: f64 = (0..1000).map(|x| (x * x) as f64).sum();
        assert!((result - expected).abs() < 1e-10);
    }

    /// Test metric computing an arithmetic mean; the state is `(sum, count)`.
    struct MeanChunkedMetric;

    impl ChunkedMetric<f64> for MeanChunkedMetric {
        type State = (f64, usize);

        fn init_state(&self) -> Self::State {
            (0.0, 0)
        }

        fn process_chunk(&self, state: &mut Self::State, chunk: &[f64]) -> Result<()> {
            let (sum, count) = state;
            for &value in chunk {
                *sum += value;
                *count += 1;
            }
            Ok(())
        }

        fn finalize(&self, state: &Self::State) -> Result<f64> {
            match *state {
                (_, 0) => Err(MetricsError::DivisionByZero),
                (sum, count) => Ok(sum / count as f64),
            }
        }
    }

    /// The chunked mean must equal the directly computed mean.
    #[test]
    fn test_compute_chunked_metric() {
        let data: Vec<f64> = (0..1000).map(|x| x as f64).collect();
        let config = ParallelConfig::default();

        let result = compute_chunked_metric(&data, &MeanChunkedMetric, 100, &config)
            .expect("Operation failed");

        let expected: f64 = data.iter().sum::<f64>() / data.len() as f64;
        assert!((result - expected).abs() < 1e-10);
    }
}