use orx_parallel::*;
use std::cell::UnsafeCell;
/// Number of input elements; `main` builds the input as `0..N`.
const N: u64 = 10_000_000;
/// Length of the per-thread metrics array and the thread count requested in `main`.
const MAX_NUM_THREADS: usize = 8;
/// Returns the `n`-th Fibonacci number, with `fibonacci(0) == 0` and `fibonacci(1) == 1`.
///
/// The largest value representable in `u64` is `fibonacci(93)`. For `n >= 94` the addition
/// overflows: it panics when overflow checks are enabled and wraps otherwise.
fn fibonacci(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    let mut prev: u64 = 0;
    let mut curr: u64 = 1;
    // Run one step fewer than `n` and return the leading term, so that no term beyond
    // F(n) is ever computed; otherwise n == 93 would overflow while computing the unused F(94).
    for _ in 1..n {
        let next = prev + curr;
        prev = curr;
        curr = next;
    }
    curr
}
/// Counters recorded for one slot of the per-thread metrics array.
#[derive(Default, Debug)]
struct ThreadMetrics {
/// Index of this slot within the metrics array (assigned in `ComputationMetrics::new`).
thread_idx: usize,
/// Incremented once per element by the `map` step in `main`.
num_items_handled: usize,
/// Becomes `true` once the input value `42` has gone through the `map` step for this slot.
handled_42: bool,
/// Incremented by the `filter` step in `main` for every mapped value that is odd (rejected).
num_filtered_out: usize,
}
/// Handle holding an exclusive reference to a single `ThreadMetrics` slot.
///
/// Values of this type are produced by `ComputationMetrics::create_for_thread` and are the
/// mutable state handed to the `map` and `filter` closures in `main`.
struct ThreadMetricsWriter<'a> {
/// The slot that this writer updates.
metrics_ref: &'a mut ThreadMetrics,
}
/// Owner of the `MAX_NUM_THREADS` metrics slots.
///
/// The array lives in an `UnsafeCell` so that individual slots can be mutated through a
/// shared `&self` (see `create_for_thread`).
struct ComputationMetrics {
/// One `ThreadMetrics` entry per thread index.
thread_metrics: UnsafeCell<[ThreadMetrics; MAX_NUM_THREADS]>,
}
impl ComputationMetrics {
fn new() -> Self {
let mut thread_metrics: [ThreadMetrics; MAX_NUM_THREADS] = Default::default();
for i in 0..MAX_NUM_THREADS {
thread_metrics[i].thread_idx = i;
}
Self {
thread_metrics: UnsafeCell::new(thread_metrics),
}
}
}
// SAFETY: the `UnsafeCell` field makes this type `!Sync` by default. Sharing it between threads
// is sound only as long as no two threads access the same array element at the same time.
// The type system does not enforce this; it is the contract callers of the unsafe
// `create_for_thread` must uphold (one distinct index per concurrently running thread).
unsafe impl Sync for ComputationMetrics {}
impl ComputationMetrics {
    /// Hands out a writer with exclusive access to the metrics slot at `thread_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `thread_idx` is not a valid index into the metrics array.
    ///
    /// # Safety
    ///
    /// The returned lifetime `'a` is chosen by the caller and is not tied to `&self`, so the
    /// compiler cannot check how the writer is used. The caller must guarantee that:
    ///
    /// * while the returned writer is in use, no other reference to the same slot exists;
    ///   in particular this method must not be called again with the same `thread_idx`
    ///   while an earlier writer for that index is still being used;
    /// * the writer is not used after `self` has been dropped.
    unsafe fn create_for_thread<'a>(&self, thread_idx: usize) -> ThreadMetricsWriter<'a> {
        let array_ptr = self.thread_metrics.get();
        // SAFETY: `array_ptr` comes from the `UnsafeCell` owned by `self`, so it is valid and
        // aligned. The index expression is bounds-checked and borrows ONLY the selected
        // element. No `&mut` to the whole array is created, so writers for different indices
        // never alias even when they are created from different threads at the same time.
        // Exclusivity of the selected element is the caller's obligation (see `# Safety`).
        let slot = unsafe { &mut (*array_ptr)[thread_idx] };
        ThreadMetricsWriter { metrics_ref: slot }
    }
}
/// Runs a parallel map/filter/sum over `0..N` while recording per-thread metrics, then prints
/// the sum and the metrics and checks that every input element was counted exactly once.
fn main() {
    let mut metrics = ComputationMetrics::new();
    let input: Vec<u64> = (0..N).collect();

    let sum = input
        .par()
        // NOTE(review): soundness relies on `using` handing each worker thread a distinct
        // index below `MAX_NUM_THREADS`; confirm against the orx_parallel documentation.
        .using(|thread_idx| unsafe { metrics.create_for_thread(thread_idx) })
        .map(|writer: &mut ThreadMetricsWriter<'_>, value| {
            writer.metrics_ref.num_items_handled += 1;
            writer.metrics_ref.handled_42 |= *value == 42;
            fibonacci((*value % 50) + 1) % 100
        })
        .filter(|writer, value| {
            // Odd values are rejected and counted; even values pass through to the sum.
            let is_odd = value % 2 != 0;
            if is_odd {
                writer.metrics_ref.num_filtered_out += 1;
            }
            !is_odd
        })
        .num_threads(MAX_NUM_THREADS)
        .sum();

    println!("\nINPUT-LEN = {N}");
    println!("SUM = {sum}");
    println!("\n\n");
    println!("COLLECTED METRICS PER THREAD");

    // The parallel computation has finished, so the slots can be read through `&mut`.
    let per_thread = metrics.thread_metrics.get_mut();
    for thread_metrics in per_thread.iter() {
        println!("* {thread_metrics:?}");
    }

    let mut total_by_metrics: usize = 0;
    for thread_metrics in per_thread.iter() {
        total_by_metrics += thread_metrics.num_items_handled;
    }
    println!("\n-> total num_items_handled by collected metrics: {total_by_metrics:?}\n");
    assert_eq!(N as usize, total_by_metrics);
}