Skip to main content

vyre_runtime/megakernel/
recovery.rs

1//! Device-loss classification and persistent-pipeline rebuild policy.
2
3use std::sync::Arc;
4
5use vyre_driver::backend::{CompiledPipeline, DispatchConfig, VyreBackend};
6use vyre_driver::BackendError;
7use vyre_foundation::ir::Program;
8
/// Outcome reported after the runtime reacted to a device-loss symptom
/// surfaced by the backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MegakernelRecoveryDecision {
    /// The compiled pipeline was rebuilt against the same backend.
    RecompiledPipeline,
}
15
/// Coarse bucket into which persistent megakernel recovery sorts backend
/// failures.
///
/// Produced by [`classify_backend_recovery_error`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MegakernelRecoveryClass {
    /// Backend context or adapter state, or the compiled pipeline itself, is
    /// lost or stale.
    DeviceLoss,
    /// Queue or resource pressure; the dispatch may be retried without
    /// recompiling.
    TransientQueue,
    /// A defect in the program, its lowering, or its kernel source; it should
    /// not be retried as-is.
    ProgramBug,
    /// Nothing safe to automate could be inferred from the failure.
    Unclassified,
}
28
/// Knobs controlling how persistent megakernel dispatch reacts to backend
/// failures.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelRecoveryPolicy {
    /// When set, a dispatch that failed with a device-loss-like backend error
    /// is retried a single time.
    pub retry_device_loss_once: bool,
}

impl Default for MegakernelRecoveryPolicy {
    /// The default policy enables the single retry after device loss.
    fn default() -> Self {
        MegakernelRecoveryPolicy {
            retry_device_loss_once: true,
        }
    }
}
43
44/// Return true when a backend error is consistent with device loss or a stale
45/// compiled pipeline.
46#[must_use]
47pub fn backend_error_indicates_device_loss(error: &BackendError) -> bool {
48    classify_backend_recovery_error(error) == MegakernelRecoveryClass::DeviceLoss
49}
50
51/// Classify a backend failure for persistent megakernel recovery.
52#[must_use]
53pub fn classify_backend_recovery_error(error: &BackendError) -> MegakernelRecoveryClass {
54    match error {
55        BackendError::DeviceOutOfMemory { .. } | BackendError::PoisonedLock { .. } => {
56            MegakernelRecoveryClass::TransientQueue
57        }
58        BackendError::KernelCompileFailed { .. }
59        | BackendError::InvalidProgram { .. }
60        | BackendError::UnsupportedFeature { .. } => MegakernelRecoveryClass::ProgramBug,
61        BackendError::DispatchFailed { message, .. } => classify_recovery_message(message),
62        BackendError::Raw(message) => classify_recovery_message(message),
63        _ => classify_recovery_message(&error.to_string()),
64    }
65}
66
67fn classify_recovery_message(message: &str) -> MegakernelRecoveryClass {
68    if text_contains_any_marker(message, DEVICE_LOSS_MARKERS) {
69        return MegakernelRecoveryClass::DeviceLoss;
70    }
71    if text_contains_any_marker(message, TRANSIENT_QUEUE_MARKERS) {
72        return MegakernelRecoveryClass::TransientQueue;
73    }
74    if text_contains_any_marker(message, PROGRAM_BUG_MARKERS) {
75        return MegakernelRecoveryClass::ProgramBug;
76    }
77    MegakernelRecoveryClass::Unclassified
78}
79
80/// Return true when a backend error is consistent with device loss or a stale
81/// compiled pipeline.
82#[must_use]
83pub fn backend_error_message_indicates_device_loss(error: &BackendError) -> bool {
84    let text = error.to_string();
85    text_contains_any_marker(&text, DEVICE_LOSS_MARKERS)
86}
87
/// Phrases marking a backend error as device loss or a stale compiled pipeline.
///
/// Entries are lower-case ASCII and are matched as ASCII case-insensitive
/// substrings, so `"devicelost"` also catches `DeviceLost`.
const DEVICE_LOSS_MARKERS: &[&str] = &[
    // Free-form wording for a lost device, context, or adapter.
    "device lost",
    "devicelost",
    "context lost",
    "lost device",
    "adapter lost",
    "gpu reset",
    // Status-code style identifiers (presumably emitted by a native driver
    // backend; TODO confirm which one).
    // NOTE(review): "context_is_current" does not read like a loss condition;
    // confirm it belongs in this table.
    "device_error_context_is_destroyed",
    "device_error_context_is_current",
    "device_error_deinitialized",
    // The device survived but the compiled pipeline can no longer be used.
    "stale pipeline",
];
100
/// Phrases marking a backend error as queue or resource pressure that can be
/// retried without recompiling.
///
/// In `classify_recovery_message`, these are only checked once no device-loss
/// marker matched.
const TRANSIENT_QUEUE_MARKERS: &[&str] = &[
    // Submission-side pressure.
    "queue full",
    "backpressure",
    // "Retry later" wording.
    "temporarily unavailable",
    "try again",
    "would block",
    // Waits that expired.
    "timeout",
    "timed out",
    // Memory pressure. "out of memory" already matches everything the longer
    // "device out of memory" entry does; that entry is redundant but harmless.
    "out of memory",
    "device out of memory",
];
112
/// Phrases marking a backend error as a defect in the submitted program, its
/// lowering, or its kernel source.
///
/// In `classify_recovery_message`, these have the lowest priority: they are
/// only checked once neither device-loss nor transient markers matched.
const PROGRAM_BUG_MARKERS: &[&str] = &[
    "invalid program",
    // "compile failed" already matches everything the longer
    // "kernel-source compile failed" entry does; that entry is redundant but harmless.
    "kernel-source compile failed",
    "compile failed",
    "unsupported feature",
    "validation failed",
    "lowering failed",
    "type error",
];
122
123fn text_contains_any_marker(text: &str, markers: &[&str]) -> bool {
124    markers
125        .iter()
126        .any(|marker| contains_ascii_case_insensitive(text, marker))
127}
128
/// ASCII case-insensitive substring test over raw UTF-8 bytes.
///
/// An empty `needle` matches every haystack. A needle longer than the
/// haystack never matches. Non-ASCII bytes must match exactly, which is
/// sufficient for this module's markers because they are all ASCII.
fn contains_ascii_case_insensitive(haystack: &str, needle: &str) -> bool {
    match needle.as_bytes() {
        [] => true,
        pattern => haystack
            .as_bytes()
            .windows(pattern.len())
            .any(|candidate| candidate.eq_ignore_ascii_case(pattern)),
    }
}
139
140/// Recompile a persistent megakernel pipeline after a recoverable device
141/// failure.
142///
143/// # Errors
144///
145/// Returns the backend compile error if the backend cannot rebuild the program.
146pub fn recover_compiled_pipeline(
147    backend: &Arc<dyn VyreBackend>,
148    program: Arc<Program>,
149    config: &DispatchConfig,
150) -> Result<Arc<dyn CompiledPipeline>, BackendError> {
151    vyre_driver::pipeline::compile_shared(Arc::clone(backend), program, config)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn recovery_classifier_separates_device_loss_transient_queue_and_program_bug() {
        let device_loss = BackendError::DispatchFailed {
            code: None,
            message: "DeviceLost after queue submit".to_string(),
        };
        assert_eq!(
            classify_backend_recovery_error(&device_loss),
            MegakernelRecoveryClass::DeviceLoss
        );
        assert!(backend_error_indicates_device_loss(&device_loss));
        assert!(backend_error_message_indicates_device_loss(&device_loss));

        let transient = BackendError::new("queue full during publish. Fix: retry after drain.");
        assert_eq!(
            classify_backend_recovery_error(&transient),
            MegakernelRecoveryClass::TransientQueue
        );
        assert!(!backend_error_indicates_device_loss(&transient));

        let program_bug = BackendError::InvalidProgram {
            fix: "Fix: validate descriptor before backend lowering.".to_string(),
        };
        assert_eq!(
            classify_backend_recovery_error(&program_bug),
            MegakernelRecoveryClass::ProgramBug
        );
    }

    #[test]
    fn recovery_classifier_prefers_device_loss_over_transient_markers() {
        let error = BackendError::new(
            "queue full because stale pipeline hit adapter lost. Fix: rebuild the pipeline.",
        );

        assert_eq!(
            classify_backend_recovery_error(&error),
            MegakernelRecoveryClass::DeviceLoss
        );
    }

    #[test]
    fn recovery_classifier_leaves_unknown_errors_unclassified() {
        let error =
            BackendError::new("backend returned vendor code 17. Fix: inspect backend logs.");

        assert_eq!(
            classify_backend_recovery_error(&error),
            MegakernelRecoveryClass::Unclassified
        );
    }

    // Markers are stored lower-case; backend text may use any ASCII case.
    #[test]
    fn recovery_classifier_matches_markers_case_insensitively() {
        let error = BackendError::DispatchFailed {
            code: None,
            message: "GPU RESET during megakernel drain".to_string(),
        };

        assert_eq!(
            classify_backend_recovery_error(&error),
            MegakernelRecoveryClass::DeviceLoss
        );
        assert!(backend_error_indicates_device_loss(&error));
        assert!(backend_error_message_indicates_device_loss(&error));
    }

    // Untyped errors fall back to message markers for every class, not just device loss.
    #[test]
    fn recovery_classifier_reads_transient_and_program_bug_markers_from_messages() {
        let timeout =
            BackendError::new("fence wait Timed Out after 2s. Fix: drain the queue and retry.");
        assert_eq!(
            classify_backend_recovery_error(&timeout),
            MegakernelRecoveryClass::TransientQueue
        );
        assert!(!backend_error_indicates_device_loss(&timeout));

        let lowering = BackendError::new("lowering failed for region 3. Fix: repair the descriptor.");
        assert_eq!(
            classify_backend_recovery_error(&lowering),
            MegakernelRecoveryClass::ProgramBug
        );
        assert!(!backend_error_indicates_device_loss(&lowering));
    }

    #[test]
    fn ascii_case_insensitive_matcher_handles_edge_cases() {
        assert!(contains_ascii_case_insensitive("anything", ""));
        assert!(contains_ascii_case_insensitive("", ""));
        assert!(!contains_ascii_case_insensitive("", "x"));
        // Needle longer than haystack must not match (and must not panic).
        assert!(!contains_ascii_case_insensitive("lost", "device lost"));
        assert!(contains_ascii_case_insensitive("Adapter LOST", "adapter lost"));
        // Non-ASCII bytes in the haystack are tolerated.
        assert!(contains_ascii_case_insensitive("\u{e9} queue full", "QUEUE FULL"));
        assert!(!text_contains_any_marker("anything", &[]));
    }
}
209}