use atomr_core::supervision::{Directive, OneForOneStrategy, SupervisorOf, SupervisorStrategy};
use std::time::Duration;
use thiserror::Error;
/// Tag that `GpuError::ContextPoisoned`'s `Display` text starts with.
///
/// The string-based routers (`decider` / `DeviceSupervisor::decide_str`) map any
/// panic message *containing* this tag to `Directive::Restart`.
pub const CONTEXT_POISONED_TAG: &str = "ContextPoisoned";
/// Tag that `GpuError::OutOfMemory`'s `Display` text starts with.
///
/// Panic messages containing it are routed to `Directive::Resume`.
pub const OUT_OF_MEMORY_TAG: &str = "OutOfMemory";
/// Tag that `GpuError::Unrecoverable`'s `Display` text starts with.
///
/// Panic messages containing it are routed to `Directive::Stop`.
pub const UNRECOVERABLE_TAG: &str = "Unrecoverable";
/// Errors produced by GPU device actors.
///
/// The first three variants render (via `Display`) with a prefix equal to one of the
/// tag constants (`CONTEXT_POISONED_TAG`, `OUT_OF_MEMORY_TAG`, `UNRECOVERABLE_TAG`).
/// This is what lets a panic carrying the error text be routed by the string-based
/// decider. Keep the `#[error]` prefixes and the constants in sync.
#[derive(Debug, Error)]
pub enum GpuError {
    /// Context-poisoned failure; `DeviceSupervisor::decide` maps it to `Restart`.
    /// NOTE(review): presumably the device context is unusable after this — confirm.
    #[error("ContextPoisoned: {0}")]
    ContextPoisoned(String),
    /// Out-of-memory failure; `DeviceSupervisor::decide` maps it to `Resume`.
    #[error("OutOfMemory: {0}")]
    OutOfMemory(String),
    /// Failure the supervisor should not retry; `DeviceSupervisor::decide` maps it to `Stop`.
    #[error("Unrecoverable: {0}")]
    Unrecoverable(String),
    /// A `GpuRef` was found to be stale; the payload is a static description.
    #[error("GpuRef stale: {0}")]
    GpuRefStale(&'static str),
    /// Error reported by the cudarc driver layer, carried as text.
    #[error("cudarc driver error: {0}")]
    Driver(String),
    /// Legacy cuBLAS-specific error; superseded by `LibraryError { lib: "cublas", .. }`.
    #[deprecated(note = "use GpuError::LibraryError { lib: \"cublas\", msg } instead")]
    #[error("cudarc cuBLAS error: {0}")]
    Cublas(String),
    /// Error from a named cudarc-wrapped library (`lib`, e.g. `"cublas"`) with its message.
    #[error("cudarc {lib} error: {msg}")]
    LibraryError { lib: &'static str, msg: String },
    /// An ask timed out before the GPU work completed.
    #[error("ask timed out before GPU completion")]
    Timeout,
}
impl GpuError {
pub fn lib(lib: &'static str, msg: impl Into<String>) -> Self {
Self::LibraryError {
lib,
msg: msg.into(),
}
}
}
impl GpuError {
pub fn panic_message(&self) -> String {
self.to_string()
}
}
pub fn decider() -> impl Fn(&str) -> Directive + Send + Sync + 'static {
|panic_msg: &str| {
if panic_msg.contains(CONTEXT_POISONED_TAG) {
Directive::Restart
} else if panic_msg.contains(OUT_OF_MEMORY_TAG) {
Directive::Resume
} else if panic_msg.contains(UNRECOVERABLE_TAG) {
Directive::Stop
} else {
Directive::Escalate
}
}
}
/// Builds the supervision strategy for GPU device actors.
///
/// The result is a one-for-one strategy configured with `max_retries = 3`, a
/// `within` window of 60 seconds, and [`decider`] to turn panic messages into
/// directives. NOTE(review): the exact retry-window semantics are defined by
/// `atomr_core`'s `OneForOneStrategy` — confirm there.
pub fn device_supervisor_strategy() -> SupervisorStrategy {
    let retry_window = Duration::from_secs(60);
    let one_for_one = OneForOneStrategy::new()
        .with_max_retries(3)
        .with_within(retry_window)
        .with_decider(decider());
    one_for_one.into()
}
/// Supervisor policy for GPU device actors.
///
/// It turns a child failure into a supervision [`Directive`]. It works either from
/// the typed [`GpuError`] ([`DeviceSupervisor::decide`]) or from raw panic text
/// ([`DeviceSupervisor::decide_str`]).
pub struct DeviceSupervisor;

impl DeviceSupervisor {
    /// Chooses the directive for a typed [`GpuError`].
    ///
    /// - `ContextPoisoned` → `Restart`
    /// - `OutOfMemory` → `Resume`
    /// - `Unrecoverable` → `Stop`
    /// - every other variant → `Escalate`
    // The deprecated `Cublas` variant must still be named. That keeps the match
    // exhaustive with no wildcard, so adding a variant is a compile error here.
    #[allow(deprecated)]
    pub fn decide(err: &GpuError) -> Directive {
        match err {
            GpuError::ContextPoisoned(..) => Directive::Restart,
            GpuError::OutOfMemory(..) => Directive::Resume,
            GpuError::Unrecoverable(..) => Directive::Stop,
            GpuError::GpuRefStale(..)
            | GpuError::Driver(..)
            | GpuError::Cublas(..)
            | GpuError::LibraryError { .. }
            | GpuError::Timeout => Directive::Escalate,
        }
    }

    /// Chooses the directive for raw panic text.
    ///
    /// The tags are searched for in priority order: [`CONTEXT_POISONED_TAG`]
    /// (`Restart`), then [`OUT_OF_MEMORY_TAG`] (`Resume`), then [`UNRECOVERABLE_TAG`]
    /// (`Stop`). The first tag found wins, and text containing none of them escalates.
    ///
    /// This is a substring search, not a prefix match. Presumably that is so tags are
    /// still found when the error text is wrapped, e.g. by `Result::unwrap`.
    ///
    /// NOTE(review): because of substring matching, a message that merely mentions
    /// a tag (say a `Driver` error whose text contains "OutOfMemory") is routed as that
    /// tag. [`DeviceSupervisor::decide`] on the typed error would escalate instead.
    pub fn decide_str(panic_msg: &str) -> Directive {
        if panic_msg.contains(CONTEXT_POISONED_TAG) {
            return Directive::Restart;
        }
        if panic_msg.contains(OUT_OF_MEMORY_TAG) {
            return Directive::Resume;
        }
        if panic_msg.contains(UNRECOVERABLE_TAG) {
            return Directive::Stop;
        }
        Directive::Escalate
    }
}
/// Exposes [`DeviceSupervisor`] through the typed supervision trait for any actor
/// type `C`, with [`GpuError`] as the child error type.
impl<C: atomr_core::actor::Actor> SupervisorOf<C> for DeviceSupervisor {
    type ChildError = GpuError;

    fn decide(&self, err: &GpuError) -> Directive {
        // This path resolves to the inherent associated fn (inherent items take
        // precedence), not back to this trait method.
        DeviceSupervisor::decide(err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- string routing via `decider()` ------------------------------------

    #[test]
    fn decider_routes_context_poisoned_to_restart() {
        let d = decider();
        assert_eq!(d("ContextPoisoned: cuInit failed"), Directive::Restart);
    }

    #[test]
    fn decider_routes_out_of_memory_to_resume() {
        let d = decider();
        assert_eq!(d("OutOfMemory: alloc 1GB"), Directive::Resume);
    }

    #[test]
    fn decider_routes_unrecoverable_to_stop() {
        let d = decider();
        assert_eq!(d("Unrecoverable: hardware fault"), Directive::Stop);
    }

    #[test]
    fn decider_escalates_unknown_panics() {
        let d = decider();
        assert_eq!(d("some other panic"), Directive::Escalate);
    }

    #[test]
    fn decider_and_decide_str_agree() {
        let d = decider();
        for &msg in &[
            "ContextPoisoned: cuInit failed",
            "OutOfMemory: alloc 1GB",
            "Unrecoverable: hardware fault",
            "some other panic",
            "",
        ] {
            assert_eq!(d(msg), DeviceSupervisor::decide_str(msg), "message: {:?}", msg);
        }
    }

    #[test]
    fn decide_str_finds_tag_in_unwrap_panic_text() {
        // `Result::unwrap` renders the error via `Debug`, e.g. `OutOfMemory("alloc 1GB")`;
        // substring matching must still find the tag.
        let err = GpuError::OutOfMemory("alloc 1GB".into());
        let msg = format!("called `Result::unwrap()` on an `Err` value: {:?}", err);
        assert_eq!(DeviceSupervisor::decide_str(&msg), Directive::Resume);
    }

    // ---- GpuError helpers ---------------------------------------------------

    #[test]
    fn tag_constants_prefix_display_text() {
        // Guards against the `#[error]` strings and the tag constants drifting apart.
        let poisoned = GpuError::ContextPoisoned("x".into()).panic_message();
        assert!(poisoned.starts_with(CONTEXT_POISONED_TAG));
        let oom = GpuError::OutOfMemory("x".into()).panic_message();
        assert!(oom.starts_with(OUT_OF_MEMORY_TAG));
        let fatal = GpuError::Unrecoverable("x".into()).panic_message();
        assert!(fatal.starts_with(UNRECOVERABLE_TAG));
    }

    #[test]
    fn panic_message_is_display_text() {
        let err = GpuError::ContextPoisoned("cuInit failed".into());
        assert_eq!(err.panic_message(), err.to_string());
        assert_eq!(err.panic_message(), "ContextPoisoned: cuInit failed");
    }

    #[test]
    fn lib_constructor_builds_library_error() {
        let err = GpuError::lib("cufft", "plan failed");
        assert_eq!(err.to_string(), "cudarc cufft error: plan failed");
        match err {
            GpuError::LibraryError { lib, msg } => {
                assert_eq!(lib, "cufft");
                assert_eq!(msg, "plan failed");
            }
            other => panic!("expected LibraryError, got {:?}", other),
        }
    }

    // ---- typed routing via `DeviceSupervisor::decide` ----------------------

    #[test]
    fn typed_supervisor_routes_context_poisoned_to_restart() {
        let err = GpuError::ContextPoisoned("simulated".into());
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Restart);
    }

    #[test]
    fn typed_supervisor_routes_oom_to_resume() {
        let err = GpuError::OutOfMemory("alloc 1GB".into());
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Resume);
    }

    #[test]
    fn typed_supervisor_routes_unrecoverable_to_stop() {
        let err = GpuError::Unrecoverable("hw fault".into());
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Stop);
    }

    #[test]
    fn typed_supervisor_escalates_other() {
        let err = GpuError::Timeout;
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Escalate);
        let err = GpuError::GpuRefStale("stale");
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Escalate);
        let err = GpuError::lib("cublas", "x");
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Escalate);
        let err = GpuError::Driver("cuLaunchKernel".into());
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Escalate);
    }

    #[test]
    #[allow(deprecated)]
    fn typed_supervisor_escalates_deprecated_cublas() {
        let err = GpuError::Cublas("gemm".into());
        assert_eq!(DeviceSupervisor::decide(&err), Directive::Escalate);
    }

    #[test]
    #[allow(deprecated)]
    fn typed_and_string_routing_agree_on_panic_message() {
        // Payloads deliberately avoid the tag words so only the variant decides.
        let errs = vec![
            GpuError::ContextPoisoned("ctx".into()),
            GpuError::OutOfMemory("alloc".into()),
            GpuError::Unrecoverable("hw".into()),
            GpuError::GpuRefStale("stale"),
            GpuError::Driver("cuLaunchKernel".into()),
            GpuError::Cublas("gemm".into()),
            GpuError::lib("cublas", "gemm"),
            GpuError::Timeout,
        ];
        for err in &errs {
            assert_eq!(
                DeviceSupervisor::decide(err),
                DeviceSupervisor::decide_str(&err.panic_message()),
                "variant: {:?}",
                err
            );
        }
    }
}