vyre_runtime/megakernel/
recovery.rs1use std::sync::Arc;
4
5use vyre_driver::backend::{CompiledPipeline, DispatchConfig, VyreBackend};
6use vyre_driver::BackendError;
7use vyre_foundation::ir::Program;
8
/// Outcome reported by a megakernel recovery step.
///
/// Only one outcome exists today. NOTE(review): the name suggests it records that
/// the pipeline was rebuilt (e.g. via `recover_compiled_pipeline`); confirm against
/// the callers that construct it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MegakernelRecoveryDecision {
    /// The compiled pipeline was recompiled.
    RecompiledPipeline,
}
15
/// Coarse recovery category assigned to a `BackendError` by
/// `classify_backend_recovery_error`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MegakernelRecoveryClass {
    /// The error text matched a device-loss marker (lost device/context/adapter,
    /// GPU reset, stale pipeline, ...).
    DeviceLoss,
    /// `DeviceOutOfMemory` / `PoisonedLock`, or text matching a transient
    /// queue marker (queue full, timeout, out of memory, ...).
    TransientQueue,
    /// `KernelCompileFailed` / `InvalidProgram` / `UnsupportedFeature`, or text
    /// matching a program-bug marker (validation/compile/lowering failures, ...).
    ProgramBug,
    /// No structured variant or text marker applied.
    Unclassified,
}
28
/// Tunables governing megakernel recovery.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelRecoveryPolicy {
    /// Whether a device-loss error should be retried a single time.
    /// NOTE(review): not read in this file; the exact retry behavior lives in
    /// the caller — confirm there.
    pub retry_device_loss_once: bool,
}

impl Default for MegakernelRecoveryPolicy {
    /// Defaults to allowing one retry after device loss.
    fn default() -> Self {
        MegakernelRecoveryPolicy { retry_device_loss_once: true }
    }
}
43
44#[must_use]
47pub fn backend_error_indicates_device_loss(error: &BackendError) -> bool {
48 classify_backend_recovery_error(error) == MegakernelRecoveryClass::DeviceLoss
49}
50
51#[must_use]
53pub fn classify_backend_recovery_error(error: &BackendError) -> MegakernelRecoveryClass {
54 match error {
55 BackendError::DeviceOutOfMemory { .. } | BackendError::PoisonedLock { .. } => {
56 MegakernelRecoveryClass::TransientQueue
57 }
58 BackendError::KernelCompileFailed { .. }
59 | BackendError::InvalidProgram { .. }
60 | BackendError::UnsupportedFeature { .. } => MegakernelRecoveryClass::ProgramBug,
61 BackendError::DispatchFailed { message, .. } => classify_recovery_message(message),
62 BackendError::Raw(message) => classify_recovery_message(message),
63 _ => classify_recovery_message(&error.to_string()),
64 }
65}
66
67fn classify_recovery_message(message: &str) -> MegakernelRecoveryClass {
68 if text_contains_any_marker(message, DEVICE_LOSS_MARKERS) {
69 return MegakernelRecoveryClass::DeviceLoss;
70 }
71 if text_contains_any_marker(message, TRANSIENT_QUEUE_MARKERS) {
72 return MegakernelRecoveryClass::TransientQueue;
73 }
74 if text_contains_any_marker(message, PROGRAM_BUG_MARKERS) {
75 return MegakernelRecoveryClass::ProgramBug;
76 }
77 MegakernelRecoveryClass::Unclassified
78}
79
80#[must_use]
83pub fn backend_error_message_indicates_device_loss(error: &BackendError) -> bool {
84 let text = error.to_string();
85 text_contains_any_marker(&text, DEVICE_LOSS_MARKERS)
86}
87
/// Case-insensitive substrings that mark an error as a lost or reset device.
///
/// `classify_recovery_message` checks these first, so they win whenever a
/// message also contains transient or program-bug markers.
const DEVICE_LOSS_MARKERS: &[&str] = &[
    // Human-readable phrasings.
    "device lost",
    "lost device",
    "context lost",
    "adapter lost",
    "gpu reset",
    // CamelCase names such as `DeviceLost` match once case is folded.
    "devicelost",
    // Pipeline state invalidated by the loss.
    "stale pipeline",
    // Driver error-code names; matching is case-insensitive, so upper-case codes hit too.
    "device_error_context_is_destroyed",
    "device_error_context_is_current",
    "device_error_deinitialized",
];
100
/// Case-insensitive substrings that mark an error as a transient queue or
/// resource condition.
///
/// These are consulted only after every device-loss marker has missed.
const TRANSIENT_QUEUE_MARKERS: &[&str] = &[
    // Queue saturation / non-blocking refusal.
    "queue full",
    "backpressure",
    "would block",
    "temporarily unavailable",
    "try again",
    // Deadlines.
    "timeout",
    "timed out",
    // Memory pressure. "device out of memory" is already covered by the
    // broader "out of memory" entry.
    "out of memory",
    "device out of memory",
];
112
/// Case-insensitive substrings that mark an error as a defect in the program or
/// its lowering.
///
/// These are consulted last, after the device-loss and transient markers.
const PROGRAM_BUG_MARKERS: &[&str] = &[
    // Validation / typing of the input program.
    "invalid program",
    "validation failed",
    "type error",
    "unsupported feature",
    // Lowering and compilation. "compile failed" also covers the more specific
    // "kernel-source compile failed".
    "lowering failed",
    "compile failed",
    "kernel-source compile failed",
];
122
123fn text_contains_any_marker(text: &str, markers: &[&str]) -> bool {
124 markers
125 .iter()
126 .any(|marker| contains_ascii_case_insensitive(text, marker))
127}
128
/// ASCII case-insensitive substring search over raw bytes.
///
/// An empty `needle` matches any haystack. Non-ASCII bytes compare exactly,
/// since only ASCII letters are case-folded. A needle longer than the haystack
/// never matches, because `windows` then yields nothing.
fn contains_ascii_case_insensitive(haystack: &str, needle: &str) -> bool {
    match needle.as_bytes() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        [] => true,
        pattern => haystack
            .as_bytes()
            .windows(pattern.len())
            .any(|chunk| chunk.eq_ignore_ascii_case(pattern)),
    }
}
139
140pub fn recover_compiled_pipeline(
147 backend: &Arc<dyn VyreBackend>,
148 program: Arc<Program>,
149 config: &DispatchConfig,
150) -> Result<Arc<dyn CompiledPipeline>, BackendError> {
151 vyre_driver::pipeline::compile_shared(Arc::clone(backend), program, config)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `error` is classified into `expected`.
    #[track_caller]
    fn assert_class(error: &BackendError, expected: MegakernelRecoveryClass) {
        assert_eq!(classify_backend_recovery_error(error), expected);
    }

    #[test]
    fn recovery_classifier_separates_device_loss_transient_queue_and_program_bug() {
        // Structured dispatch failure whose message names a lost device.
        let lost = BackendError::DispatchFailed {
            code: None,
            message: "DeviceLost after queue submit".to_string(),
        };
        assert_class(&lost, MegakernelRecoveryClass::DeviceLoss);
        assert!(backend_error_indicates_device_loss(&lost));
        assert!(backend_error_message_indicates_device_loss(&lost));

        // Free-form error carrying only a transient queue marker.
        let busy = BackendError::new("queue full during publish. Fix: retry after drain.");
        assert_class(&busy, MegakernelRecoveryClass::TransientQueue);
        assert!(!backend_error_indicates_device_loss(&busy));

        // Structured program bug: classified by variant alone.
        let bad_program = BackendError::InvalidProgram {
            fix: "Fix: validate descriptor before backend lowering.".to_string(),
        };
        assert_class(&bad_program, MegakernelRecoveryClass::ProgramBug);
    }

    #[test]
    fn recovery_classifier_prefers_device_loss_over_transient_markers() {
        // Contains both a transient marker ("queue full") and device-loss markers.
        let mixed = BackendError::new(
            "queue full because stale pipeline hit adapter lost. Fix: rebuild the pipeline.",
        );
        assert_class(&mixed, MegakernelRecoveryClass::DeviceLoss);
    }

    #[test]
    fn recovery_classifier_leaves_unknown_errors_unclassified() {
        let opaque =
            BackendError::new("backend returned vendor code 17. Fix: inspect backend logs.");
        assert_class(&opaque, MegakernelRecoveryClass::Unclassified);
    }
}