use std::any::Any;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use super::SpinBarrier;
/// Renders a panic payload (such as the `Err` value produced by
/// `std::panic::catch_unwind` or `JoinHandle::join`) as text.
///
/// * A `String` payload is returned as a copy.
/// * A `&'static str` payload is converted to an owned `String`.
/// * Any other payload type yields `"<non-string panic payload>"`.
/// * `None` yields an empty string.
pub fn panic_message(err: Option<&Box<dyn Any + Send>>) -> String {
    let payload = match err {
        Some(p) => p,
        None => return String::new(),
    };
    // Formatted `panic!("{}", ..)` messages arrive as an owned String.
    if let Some(owned) = payload.downcast_ref::<String>() {
        return owned.clone();
    }
    // Literal `panic!("...")` messages arrive as a &'static str.
    if let Some(literal) = payload.downcast_ref::<&'static str>() {
        return String::from(*literal);
    }
    "<non-string panic payload>".to_string()
}
/// Spawns an OS thread that pins itself to `host_core` and then runs `body`.
///
/// `body` receives its own `Arc` handle to `barrier`. If the worker panics, the
/// barrier is poisoned (`SpinBarrier::poison`) before the panic is re-raised
/// with its original payload. The returned `JoinHandle::join` therefore yields
/// `Err(payload)` exactly as if the panic had not been intercepted. This applies
/// to panics raised inside `body` and to panics from pinning the thread.
/// NOTE(review): poisoning presumably wakes or fails peers waiting on the
/// barrier so they do not spin forever — confirm against `SpinBarrier`.
pub fn spawn_worker<F, R>(host_core: usize, barrier: Arc<SpinBarrier>, body: F) -> JoinHandle<R>
where
    F: FnOnce(Arc<SpinBarrier>) -> R + Send + 'static,
    R: Send + 'static,
{
    thread::spawn(move || {
        let b_for_body = Arc::clone(&barrier);
        // Pinning happens inside the guarded closure. A failed affinity call
        // (which panics) must also poison the barrier; otherwise peers would
        // never learn this worker is gone.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(move || {
            pin_to_host_core(host_core);
            body(b_for_body)
        }));
        match outcome {
            Ok(r) => r,
            Err(payload) => {
                barrier.poison();
                std::panic::resume_unwind(payload);
            }
        }
    })
}
/// Pins the calling thread to the single host CPU `host_core`.
///
/// * Windows: sets the thread affinity mask to `1 << host_core` with
///   `SetThreadAffinityMask`.
/// * Linux: applies a one-CPU `cpu_set_t` with `pthread_setaffinity_np`.
/// * Other targets: the index is validated but no affinity is applied.
///
/// # Panics
///
/// Panics in two cases:
///
/// * `host_core` does not fit in the platform's affinity mask. On Windows and
///   other targets the limit is `usize::BITS` bits. On Linux it is the bit
///   width of `cpu_set_t`.
/// * The OS rejects the affinity request.
pub fn pin_to_host_core(host_core: usize) {
    #[cfg(target_os = "windows")]
    {
        use winapi::um::processthreadsapi::GetCurrentThread;
        use winapi::um::winbase::SetThreadAffinityMask;
        // The Windows affinity mask is a single pointer-sized bitmask.
        assert!(
            host_core < usize::BITS as usize,
            "host_core {host_core} exceeds processor-mask bit width"
        );
        // SAFETY: GetCurrentThread takes no arguments and only returns a
        // pseudo-handle for the calling thread.
        let h = unsafe { GetCurrentThread() };
        let mask = 1usize << host_core;
        // SAFETY: `h` refers to the current thread; `mask` is a plain bitmask.
        let prev = unsafe { SetThreadAffinityMask(h, mask) };
        assert!(
            prev != 0,
            "SetThreadAffinityMask failed for host core {host_core}"
        );
    }
    #[cfg(target_os = "linux")]
    {
        // cpu_set_t is a fixed-size bitset (CPU_SETSIZE bits), so it can address
        // far more cores than a single usize mask. Bound by its real width
        // instead of usize::BITS, which would reject cores >= 64.
        let max_cpus = 8 * std::mem::size_of::<libc::cpu_set_t>();
        assert!(
            host_core < max_cpus,
            "host_core {host_core} exceeds cpu_set_t bit width ({max_cpus})"
        );
        // SAFETY: cpu_set_t is a plain bitset; all-zero is a valid (empty) set.
        let mut set: libc::cpu_set_t = unsafe { std::mem::zeroed() };
        // SAFETY: `set` is a valid, exclusively borrowed cpu_set_t, and
        // `host_core` was bounds-checked against its bit width above.
        unsafe {
            libc::CPU_ZERO(&mut set);
            libc::CPU_SET(host_core, &mut set);
        }
        // SAFETY: pthread_self() is the calling thread, and the size argument
        // matches the type of `set`.
        let rc = unsafe {
            libc::pthread_setaffinity_np(
                libc::pthread_self(),
                std::mem::size_of::<libc::cpu_set_t>(),
                &set,
            )
        };
        // pthread_setaffinity_np returns the error number directly (it does not
        // set errno), so decode `rc` itself for a readable message.
        assert!(
            rc == 0,
            "pthread_setaffinity_np failed for host core {host_core}: errno={rc} ({})",
            std::io::Error::from_raw_os_error(rc)
        );
    }
    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
    {
        // No affinity API is used here; keep the same index sanity check so
        // out-of-range indices are still reported on every platform.
        assert!(
            host_core < usize::BITS as usize,
            "host_core {host_core} exceeds processor-mask bit width"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxes `value` the same way `catch_unwind` hands back a panic payload.
    fn boxed<T: std::any::Any + Send>(value: T) -> Box<dyn std::any::Any + Send> {
        Box::new(value)
    }

    #[test]
    fn panic_message_extracts_all_payload_kinds() {
        // (payload, expected rendering) pairs covering every branch.
        let cases: [(Box<dyn std::any::Any + Send>, &str); 3] = [
            (boxed(String::from("string panic")), "string panic"),
            (boxed("static-str panic"), "static-str panic"),
            (boxed(42u64), "<non-string panic payload>"),
        ];
        for (payload, expected) in &cases {
            assert_eq!(panic_message(Some(payload)), *expected);
        }
        // A missing payload renders as the empty string.
        assert_eq!(panic_message(None), "");
    }
}