use anyhow::{Result, anyhow};
use crate::log::StageLogger;
use std::sync::{Mutex, MutexGuard};
/// Locks `m`, recovering the guard even if the mutex has been poisoned.
///
/// A poisoned mutex is not treated as fatal. A warning prefixed with `label` is
/// emitted through `log.warn`, and the guard is taken out of the `PoisonError`.
/// The caller therefore sees whatever state the panicking holder left behind.
pub fn lock_recover<'a, T>(m: &'a Mutex<T>, log: &StageLogger, label: &str) -> MutexGuard<'a, T> {
    m.lock().unwrap_or_else(|poison_err| {
        log.warn(&format!(
            "{label}: mutex poisoned by sibling thread panic; recovering state"
        ));
        // The data is still usable; only the "a holder panicked" flag is being ignored.
        poison_err.into_inner()
    })
}
/// Converts the result of joining a thread into an `anyhow` [`Result`].
///
/// A successful join passes through unchanged. If the thread panicked, the panic
/// payload becomes an error with the message `"{label} worker thread panicked: {msg}"`.
/// Here `msg` is:
/// - the panic message, when the payload is a `&'static str` or a `String` (the two
///   payload types `panic!` produces);
/// - a fixed placeholder for any other payload type.
///
/// Accepts the output of both `JoinHandle::join` and `ScopedJoinHandle::join`, since
/// both return `std::thread::Result<T>`.
pub fn join_panic_to_err<T>(join_result: std::thread::Result<T>, label: &str) -> Result<T> {
    join_result.map_err(|payload| {
        // Take a `String` payload by value instead of cloning it through `downcast_ref`.
        let msg = match payload.downcast::<String>() {
            Ok(owned) => *owned,
            Err(payload) => match payload.downcast_ref::<&'static str>() {
                Some(s) => (*s).to_string(),
                // `dyn Any`'s Debug impl only ever prints `Any { .. }`, which carries no
                // information, so state plainly that the payload was not a string.
                None => "<non-string panic payload>".to_string(),
            },
        };
        anyhow!("{label} worker thread panicked: {msg}")
    })
}
pub fn run_parallel_chunks<J, T, F>(
jobs: &[J],
parallelism: usize,
stage_name: &'static str,
run_job: F,
) -> Result<Vec<T>>
where
J: Sync,
T: Send,
F: Fn(&J) -> Result<T> + Sync,
{
let parallelism = parallelism.max(1);
let mut results: Vec<T> = Vec::with_capacity(jobs.len());
for chunk in jobs.chunks(parallelism) {
let chunk_results: Vec<Result<T>> = std::thread::scope(|s| {
let handles: Vec<_> = chunk.iter().map(|job| s.spawn(|| run_job(job))).collect();
handles
.into_iter()
.map(|h| {
h.join()
.unwrap_or_else(|_| Err(anyhow!("{} worker thread panicked", stage_name)))
})
.collect()
});
for r in chunk_results {
results.push(r?);
}
}
Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn preserves_submission_order() {
        let jobs: Vec<u32> = (0..20).collect();
        let out = run_parallel_chunks(&jobs, 4, "test", |job| Ok(*job * 10)).unwrap();
        assert_eq!(out, (0..20).map(|i| i * 10).collect::<Vec<_>>());
    }

    #[test]
    fn bounded_concurrency() {
        let jobs: Vec<u32> = (0..10).collect();
        let in_flight = AtomicUsize::new(0);
        let peak = AtomicUsize::new(0);
        run_parallel_chunks(&jobs, 2, "test", |_| {
            let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            peak.fetch_max(now, Ordering::SeqCst);
            std::thread::sleep(std::time::Duration::from_millis(10));
            in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(())
        })
        .unwrap();
        assert!(
            peak.load(Ordering::SeqCst) <= 2,
            "peak in-flight workers exceeded parallelism bound"
        );
    }

    #[test]
    fn propagates_first_error() {
        let jobs: Vec<u32> = (0..4).collect();
        let result = run_parallel_chunks(&jobs, 2, "test", |job| {
            if *job == 2 {
                Err(anyhow!("job 2 failed"))
            } else {
                Ok(*job)
            }
        });
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("job 2 failed"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn error_stops_later_chunks() {
        // Chunks are [0, 1], [2, 3], [4, 5]. The second chunk fails; it still runs to
        // completion, but the third chunk must never be started.
        let jobs: Vec<u32> = (0..6).collect();
        let executed = AtomicUsize::new(0);
        let result = run_parallel_chunks(&jobs, 2, "test", |job| {
            executed.fetch_add(1, Ordering::SeqCst);
            if *job == 2 {
                Err(anyhow!("job 2 failed"))
            } else {
                Ok(*job)
            }
        });
        assert!(result.is_err(), "expected an error from job 2");
        assert_eq!(executed.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn reports_earliest_error_in_submission_order() {
        // Both jobs fail, and job 1 typically fails first because job 0 sleeps. The
        // reported error must still be job 0's, since results are joined in order.
        let jobs: Vec<u32> = vec![0, 1];
        let err = run_parallel_chunks(&jobs, 2, "test", |job| -> Result<u32> {
            if *job == 0 {
                std::thread::sleep(std::time::Duration::from_millis(10));
            }
            Err(anyhow!("job {job} failed"))
        })
        .unwrap_err();
        assert_eq!(err.to_string(), "job 0 failed");
    }

    #[test]
    fn zero_parallelism_clamps_to_one() {
        let jobs: Vec<u32> = (0..3).collect();
        let out = run_parallel_chunks(&jobs, 0, "test", |job| Ok(*job + 1)).unwrap();
        assert_eq!(out, vec![1, 2, 3]);
    }

    #[test]
    fn empty_jobs_returns_empty() {
        let out: Vec<u32> = run_parallel_chunks::<u32, u32, _>(&[], 4, "test", |_| Ok(0)).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn panic_in_worker_becomes_anyhow_error() {
        let jobs: Vec<u32> = vec![1, 2, 3];
        let result = run_parallel_chunks(&jobs, 2, "explode-stage", |job| -> Result<u32> {
            if *job == 2 {
                panic!("boom");
            }
            Ok(*job)
        });
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("explode-stage worker thread panicked"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn lock_recover_returns_inner_when_unpoisoned() {
        let log = StageLogger::new("test", crate::log::Verbosity::Quiet);
        let m = Mutex::new(0u32);
        {
            let mut g = lock_recover(&m, &log, "test");
            *g = 42;
        }
        assert_eq!(*m.lock().unwrap(), 42);
    }

    #[test]
    fn lock_recover_recovers_from_poison() {
        let log = StageLogger::new("test", crate::log::Verbosity::Quiet);
        let m = std::sync::Arc::new(Mutex::new(7u32));
        let m_for_thread = std::sync::Arc::clone(&m);
        // Panic while holding the guard so the mutex becomes poisoned.
        let h = std::thread::spawn(move || {
            let _g = m_for_thread.lock().unwrap();
            panic!("poison the mutex");
        });
        let _ = h.join();
        assert!(m.is_poisoned(), "test setup: mutex should be poisoned");
        let g = lock_recover(&m, &log, "test");
        assert_eq!(*g, 7);
    }

    #[test]
    fn join_panic_to_err_passes_through_success() {
        let h = std::thread::spawn(|| 42u32);
        let r = join_panic_to_err(h.join(), "worker").unwrap();
        assert_eq!(r, 42);
    }

    #[test]
    fn join_panic_to_err_translates_str_panic() {
        let h = std::thread::spawn(|| -> u32 {
            panic!("kaboom");
        });
        let err = join_panic_to_err(h.join(), "worker").unwrap_err();
        let s = err.to_string();
        assert!(
            s.contains("worker worker thread panicked") && s.contains("kaboom"),
            "unexpected error: {}",
            s
        );
    }

    #[test]
    fn join_panic_to_err_translates_string_panic() {
        let h = std::thread::spawn(|| -> u32 {
            panic!("{}", String::from("string-panic"));
        });
        let err = join_panic_to_err(h.join(), "worker").unwrap_err();
        assert!(
            err.to_string().contains("string-panic"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn join_panic_to_err_works_on_scoped_handle() {
        let out: Result<u32> = std::thread::scope(|s| {
            let h = s.spawn(|| 99u32);
            join_panic_to_err(h.join(), "scoped")
        });
        assert_eq!(out.unwrap(), 99);
    }
}