use std::sync::atomic::{AtomicU32, Ordering};
use cudarc::driver::{sys, DriverError};
use log::warn;
/// Number of error-code slots in an `ErrorSink`; errors recorded while all
/// slots are occupied are counted but their codes are not stored.
const ERR_STACK_SIZE: usize = 16;
/// Fixed-capacity stack of driver error codes that can be pushed to and
/// popped from through a shared reference, using atomics only.
#[derive(Debug)]
pub struct ErrorSink {
/// Slots holding recorded error codes as raw `u32` values; 0 marks a
/// cleared slot.
err_state: [AtomicU32; ERR_STACK_SIZE],
/// Count of errors recorded and not yet popped. It can exceed
/// `ERR_STACK_SIZE` when more errors were recorded than there are slots.
err_stack: AtomicU32,
}
impl ErrorSink {
    /// Creates an empty sink: every slot is cleared to 0 and the depth counter
    /// starts at 0.
    pub fn new() -> Self {
        Self {
            err_state: [const { AtomicU32::new(0) }; ERR_STACK_SIZE],
            err_stack: AtomicU32::new(0),
        }
    }

    /// Records the error carried by `result`, if any; `Ok` values are ignored.
    ///
    /// Every error bumps the depth counter and claims the slot at the previous
    /// depth. Once all `ERR_STACK_SIZE` slots are claimed, further errors are
    /// still counted but their codes are dropped.
    pub fn record_err<T>(&self, result: Result<T, DriverError>) {
        if let Err(err) = result {
            let idx = self.err_stack.fetch_add(1, Ordering::Relaxed) as usize;
            // `get` returns `None` past the last slot, which is the overflow case.
            if let Some(slot) = self.err_state.get(idx) {
                slot.store(err.0 as u32, Ordering::Relaxed);
            }
        }
    }

    /// Pops the most recently recorded error, if there is one.
    ///
    /// Returns `Ok(())` when no error is pending. Otherwise the top slot is
    /// cleared and its code is returned as `Err`. If more errors were recorded
    /// than there are slots, a warning is logged and the counter is clamped so
    /// that the remaining stored errors can still be popped one by one.
    pub fn check_err(&self) -> Result<(), DriverError> {
        const CAPACITY: u32 = ERR_STACK_SIZE as u32;

        // `fetch_update` hands back the value from *before* the update, i.e.
        // the stack depth prior to this pop, not the decremented counter.
        let Ok(depth) = self.err_stack.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |old| (old > 0).then(|| old.min(CAPACITY) - 1),
        ) else {
            return Ok(());
        };

        // A depth of exactly CAPACITY means the stack is full but nothing was
        // dropped; only a strictly larger depth indicates lost errors.
        if depth > CAPACITY {
            warn!("error stack overflow detected");
        }

        // The top element lives one below the pre-pop depth.
        let top = (depth.min(CAPACITY) - 1) as usize;
        let err_state = self.err_state[top].swap(0, Ordering::Relaxed);

        // NOTE(review): `record_err` bumps the counter before it stores the
        // code, so a concurrent `check_err` may observe a still-cleared slot
        // and report code 0 — confirm whether callers can race like this.
        //
        // SAFETY: slots only ever contain 0 or a value produced by
        // `err.0 as u32` in `record_err`, which is a valid discriminant of
        // `sys::cudaError_enum`; 0 is presumed to be its success variant.
        Err(DriverError(unsafe {
            std::mem::transmute::<u32, sys::cudaError_enum>(err_state)
        }))
    }
}