// scirs2_metrics/optimization/parallel.rs
use parking_lot;
6use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
7use scirs2_core::parallel_ops::*;
8use std::sync::Arc;
9
10use crate::error::Result;
11
/// Boxed metric-function signature used by [`compute_metrics_batch`]: takes a
/// ground-truth array and a prediction array and yields a scalar score.
/// `Send + Sync` so the same boxed closure can be invoked from worker threads.
pub type ParallelMetricFn<S1, S2, D1, D2> =
    dyn Fn(&ArrayBase<S1, D1>, &ArrayBase<S2, D2>) -> Result<f64> + Send + Sync;
15
/// Tuning knobs for parallel metric computation.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Inputs no larger than this are handled sequentially by the chunked helpers.
    pub min_chunk_size: usize,
    /// Master switch: when `false`, all helpers take their sequential code paths.
    pub parallel_enabled: bool,
    /// Optional worker-thread count; `None` defers to the thread-pool default.
    /// NOTE(review): not read anywhere in this file — presumably consumed by the
    /// pool setup elsewhere; confirm.
    pub num_threads: Option<usize>,
}
29
30impl Default for ParallelConfig {
31 fn default() -> Self {
32 ParallelConfig {
33 min_chunk_size: 1000,
34 parallel_enabled: true,
35 num_threads: None,
36 }
37 }
38}
39
40impl ParallelConfig {
41 pub fn new() -> Self {
43 Default::default()
44 }
45
46 pub fn with_min_chunk_size(mut self, size: usize) -> Self {
48 self.min_chunk_size = size;
49 self
50 }
51
52 pub fn with_parallel_enabled(mut self, enabled: bool) -> Self {
54 self.parallel_enabled = enabled;
55 self
56 }
57
58 pub fn with_num_threads(mut self, threads: Option<usize>) -> Self {
60 self.num_threads = threads;
61 self
62 }
63}
64
/// A metric that can compute itself over an array, honoring a [`ParallelConfig`].
pub trait ParallelMetric<T, D>
where
    T: Send + Sync,
    D: Dimension,
{
    /// Computes the metric value for `x` under the given parallel `config`.
    ///
    /// # Errors
    /// Returns an error if the underlying metric computation fails.
    fn compute_parallel(
        &self,
        x: &ArrayBase<impl Data<Elem = T>, D>,
        config: &ParallelConfig,
    ) -> Result<f64>;
}
78
79#[allow(dead_code)]
120pub fn compute_metrics_batch<T, S1, S2, D1, D2>(
121 y_true: &ArrayBase<S1, D1>,
122 y_pred: &ArrayBase<S2, D2>,
123 metric_fns: &[Box<ParallelMetricFn<S1, S2, D1, D2>>],
124 config: &ParallelConfig,
125) -> Result<Vec<f64>>
126where
127 T: Clone + Send + Sync,
128 S1: Data<Elem = T> + Sync,
129 S2: Data<Elem = T> + Sync,
130 D1: Dimension + Sync,
131 D2: Dimension + Sync,
132{
133 if !config.parallel_enabled || metric_fns.len() < 2 {
134 let mut results = Vec::with_capacity(metric_fns.len());
136 for metric_fn in metric_fns {
137 let value = metric_fn(y_true, y_pred)?;
138 results.push(value);
139 }
140 return Ok(results);
141 }
142
143 let results: Result<Vec<f64>> = metric_fns
145 .par_iter()
146 .map(|metric_fn| metric_fn(y_true, y_pred))
147 .collect();
148
149 results
150}
151
152#[allow(dead_code)]
194pub fn chunked_parallel_compute<T, R>(
195 data: &[T],
196 chunk_size: usize,
197 chunk_op: impl Fn(&[T]) -> Result<R> + Send + Sync,
198 reducer: impl Fn(Vec<R>) -> Result<R>,
199) -> Result<R>
200where
201 T: Clone + Send + Sync,
202 R: Send + Sync,
203{
204 if data.len() <= chunk_size {
205 return chunk_op(data);
207 }
208
209 let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
211
212 let results: Result<Vec<R>> = chunks.par_iter().map(|chunk| chunk_op(chunk)).collect();
214
215 reducer(results?)
217}
218
/// A metric that can be computed incrementally over chunks of data.
pub trait ChunkedMetric<T> {
    /// Intermediate accumulator threaded through `process_chunk` calls.
    type State: Send + Sync;

    /// Returns a fresh, empty accumulator.
    fn init_state(&self) -> Self::State;

    /// Folds one chunk of input into `state`.
    ///
    /// # Errors
    /// Implementations may fail on invalid chunk contents.
    fn process_chunk(&self, state: &mut Self::State, chunk: &[T]) -> Result<()>;

    /// Converts the accumulated state into the final metric value.
    ///
    /// # Errors
    /// Implementations may fail, e.g. when no data was accumulated.
    fn finalize(&self, state: &Self::State) -> Result<f64>;
}
233
234#[allow(dead_code)]
247pub fn compute_chunked_metric<T, M>(
248 data: &[T],
249 metric: &M,
250 chunk_size: usize,
251 config: &ParallelConfig,
252) -> Result<f64>
253where
254 T: Clone + Send + Sync,
255 M: ChunkedMetric<T> + Send + Sync,
256{
257 if data.len() <= chunk_size || !config.parallel_enabled {
258 let mut state = metric.init_state();
260 metric.process_chunk(&mut state, data)?;
261 return metric.finalize(&state);
262 }
263
264 let state = Arc::new(parking_lot::Mutex::new(metric.init_state()));
266 let metric = Arc::new(metric);
267
268 let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
270
271 let result: Result<()> = chunks.par_iter().try_for_each(|chunk| {
273 let mut local_state = metric.init_state();
274 metric.process_chunk(&mut local_state, chunk)?;
275
276 let mut global_state = state.lock();
278 metric.process_chunk(&mut *global_state, chunk)?;
279 Ok(())
280 });
281
282 result?;
284
285 let state_lock = state.lock();
287 let result = metric.finalize(&*state_lock);
288 drop(state_lock); result
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::MetricsError;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_parallel_config() {
        // Builder methods should store each setting verbatim.
        let config = ParallelConfig::new()
            .with_min_chunk_size(500)
            .with_parallel_enabled(true)
            .with_num_threads(Some(4));

        assert_eq!(config.min_chunk_size, 500);
        assert!(config.parallel_enabled);
        assert_eq!(config.num_threads, Some(4));
    }

    #[test]
    fn test_compute_metrics_batch() {
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        type MetricFn = Box<dyn Fn(&Array1<i32>, &Array1<i32>) -> Result<f64> + Send + Sync>;
        let metric_fns: Vec<MetricFn> = vec![
            // Accuracy: fraction of positions where prediction matches truth.
            Box::new(|a, b| {
                if a.len() != b.len() {
                    return Err(MetricsError::InvalidInput("Lengths must match".to_string()));
                }
                let correct = a.iter().zip(b.iter()).filter(|&(a, b)| a == b).count();
                Ok(correct as f64 / a.len() as f64)
            }),
            // Trivial second metric: the sample count.
            Box::new(|a, _b| {
                Ok(a.len() as f64)
            }),
        ];

        // Sequential path.
        let config = ParallelConfig::new().with_parallel_enabled(false);
        let results = compute_metrics_batch(&y_true, &y_pred, &metric_fns, &config).unwrap();

        assert_eq!(results.len(), 2);
        // 3 of 6 predictions match -> accuracy 0.5; count metric -> 6.
        assert!((results[0] - 0.5).abs() < 1e-10);
        assert!((results[1] - 6.0).abs() < 1e-10);

        // Parallel path must produce identical values in the same order.
        let config = ParallelConfig::new().with_parallel_enabled(true);
        let results = compute_metrics_batch(&y_true, &y_pred, &metric_fns, &config).unwrap();

        assert_eq!(results.len(), 2);
        assert!((results[0] - 0.5).abs() < 1e-10);
        assert!((results[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_chunked_parallel_compute() {
        let data: Vec<f64> = (0..1000).map(|x| x as f64).collect();

        // Per-chunk: sum of squares; reducer: sum of the partial sums.
        let chunk_op = |chunk: &[f64]| -> Result<f64> { Ok(chunk.iter().map(|x| x * x).sum()) };

        let reducer = |results: Vec<f64>| -> Result<f64> { Ok(results.iter().sum()) };

        let result = chunked_parallel_compute(&data, 100, chunk_op, reducer).unwrap();

        let expected: f64 = (0..1000).map(|x| (x * x) as f64).sum();
        assert!((result - expected).abs() < 1e-10);
    }

    // Arithmetic-mean metric used to exercise the ChunkedMetric machinery.
    struct MeanChunkedMetric;

    impl ChunkedMetric<f64> for MeanChunkedMetric {
        // (running sum, element count)
        type State = (f64, usize);

        fn init_state(&self) -> Self::State {
            (0.0, 0)
        }

        fn process_chunk(&self, state: &mut Self::State, chunk: &[f64]) -> Result<()> {
            for &value in chunk {
                state.0 += value;
                state.1 += 1;
            }
            Ok(())
        }

        fn finalize(&self, state: &Self::State) -> Result<f64> {
            // Guard against an empty accumulator before dividing.
            if state.1 == 0 {
                return Err(MetricsError::DivisionByZero);
            }
            Ok(state.0 / state.1 as f64)
        }
    }

    #[test]
    fn test_compute_chunked_metric() {
        let data: Vec<f64> = (0..1000).map(|x| x as f64).collect();

        let metric = MeanChunkedMetric;

        let config = ParallelConfig::default();
        let result = compute_chunked_metric(&data, &metric, 100, &config).unwrap();

        let expected: f64 = data.iter().sum::<f64>() / data.len() as f64;
        assert!((result - expected).abs() < 1e-10);
    }
}