Skip to main content

oxibonsai_runtime/
recovery.rs

//! Error recovery strategies for production resilience.
//!
//! Provides retry logic with exponential backoff, synchronous timeouts,
//! error classification, and memory-aware batch sizing.

6use std::time::Duration;
7
8use crate::error::{RuntimeError, RuntimeResult};
9
/// Recovery strategy for different error types.
///
/// Produced by [`recovery_strategy_for`]. The `Display` impl renders a short,
/// human-readable description suitable for logs. `PartialEq`/`Eq` are derived
/// so strategies can be compared directly (e.g. in tests or policy tables).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Retry the operation with backoff.
    Retry {
        /// Maximum number of retry attempts.
        ///
        /// NOTE(review): `retry_with_backoff` counts the initial call as an
        /// attempt, so passing this value straight through with `1` means "no
        /// retry". Confirm which meaning producers of this value intend.
        max_attempts: usize,
        /// Base delay between retries (doubled each attempt).
        delay: Duration,
    },
    /// Fall back to an alternative approach, described by the contained text.
    Fallback(String),
    /// Abort — the error is not recoverable.
    Abort,
}

impl std::fmt::Display for RecoveryStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry {
                max_attempts,
                delay,
            } => write!(
                f,
                "retry (max {} attempts, base delay {:?})",
                max_attempts, delay
            ),
            Self::Fallback(desc) => write!(f, "fallback: {}", desc),
            Self::Abort => write!(f, "abort"),
        }
    }
}
42
43/// Determine the appropriate recovery strategy for a given error.
44pub fn recovery_strategy_for(error: &RuntimeError) -> RecoveryStrategy {
45    match error {
46        // IO errors are often transient
47        RuntimeError::Io(_) => RecoveryStrategy::Retry {
48            max_attempts: 3,
49            delay: Duration::from_millis(100),
50        },
51        // Timeout errors should be retried with longer timeout
52        RuntimeError::Timeout { .. } => RecoveryStrategy::Retry {
53            max_attempts: 2,
54            delay: Duration::from_millis(500),
55        },
56        // Capacity errors: wait and retry
57        RuntimeError::CapacityExhausted { .. } => RecoveryStrategy::Retry {
58            max_attempts: 3,
59            delay: Duration::from_millis(200),
60        },
61        // Circuit open: wait for recovery
62        RuntimeError::CircuitOpen => RecoveryStrategy::Retry {
63            max_attempts: 1,
64            delay: Duration::from_secs(5),
65        },
66        // Config errors are permanent
67        RuntimeError::Config(_) => RecoveryStrategy::Abort,
68        // File not found is permanent
69        RuntimeError::FileNotFound { .. } => RecoveryStrategy::Abort,
70        // Tokenizer errors may benefit from fallback
71        RuntimeError::Tokenizer(_) => RecoveryStrategy::Fallback("use raw token IDs".to_string()),
72        // Generation stopped is not really an error
73        RuntimeError::GenerationStopped { .. } => RecoveryStrategy::Abort,
74        // Server errors may be transient
75        RuntimeError::Server(_) => RecoveryStrategy::Retry {
76            max_attempts: 2,
77            delay: Duration::from_millis(200),
78        },
79        // Core/kernel/model errors are generally permanent
80        RuntimeError::Core(_) => RecoveryStrategy::Abort,
81        RuntimeError::Kernel(_) => RecoveryStrategy::Abort,
82        RuntimeError::Model(_) => RecoveryStrategy::Abort,
83        // Batch errors: check individual errors
84        RuntimeError::BatchError(_) => RecoveryStrategy::Retry {
85            max_attempts: 1,
86            delay: Duration::from_millis(100),
87        },
88    }
89}
90
91/// Retry a fallible operation with exponential backoff.
92///
93/// Calls `f` up to `max_attempts` times. If `f` succeeds, returns the result
94/// immediately. If all attempts fail, returns the last error.
95///
96/// The delay between attempts doubles each time, starting from `base_delay`.
97pub fn retry_with_backoff<F, T>(
98    max_attempts: usize,
99    base_delay: Duration,
100    mut f: F,
101) -> RuntimeResult<T>
102where
103    F: FnMut() -> RuntimeResult<T>,
104{
105    let attempts = max_attempts.max(1);
106    let mut last_error = None;
107    let mut delay = base_delay;
108
109    for attempt in 0..attempts {
110        match f() {
111            Ok(val) => return Ok(val),
112            Err(e) => {
113                tracing::debug!(
114                    attempt = attempt + 1,
115                    max_attempts = attempts,
116                    error = %e,
117                    "retry attempt failed"
118                );
119                last_error = Some(e);
120
121                if attempt + 1 < attempts {
122                    std::thread::sleep(delay);
123                    delay = delay.saturating_mul(2);
124                }
125            }
126        }
127    }
128
129    Err(last_error.unwrap_or_else(|| {
130        RuntimeError::Config("retry_with_backoff called with zero attempts".to_string())
131    }))
132}
133
134/// Execute a closure with a synchronous timeout.
135///
136/// Spawns the closure on a separate thread and waits for it to complete
137/// within the specified duration. Returns a timeout error if it takes too long.
138pub fn with_timeout<F, T>(duration: Duration, f: F) -> RuntimeResult<T>
139where
140    F: FnOnce() -> RuntimeResult<T> + Send + 'static,
141    T: Send + 'static,
142{
143    let (tx, rx) = std::sync::mpsc::channel();
144    std::thread::spawn(move || {
145        let result = f();
146        let _ = tx.send(result);
147    });
148
149    rx.recv_timeout(duration).unwrap_or_else(|e| match e {
150        std::sync::mpsc::RecvTimeoutError::Timeout => Err(RuntimeError::Timeout {
151            operation: "with_timeout".to_string(),
152            duration_ms: duration.as_millis() as u64,
153        }),
154        std::sync::mpsc::RecvTimeoutError::Disconnected => Err(RuntimeError::Server(
155            "timeout worker thread panicked".to_string(),
156        )),
157    })
158}
159
/// Calculate recommended batch size based on available memory.
///
/// Returns how many requests of `per_request_memory_bytes` fit into
/// `available_memory_bytes`, capped at `max_batch` and never less than `1`.
/// A single request is always allowed, even if it does not fit.
///
/// If `per_request_memory_bytes` is `0`, memory is not a constraint and
/// `max_batch` is returned unchanged.
///
/// NOTE(review): when `max_batch` is `0` and `per_request_memory_bytes` is
/// non-zero, the floor of `1` wins and `1` is returned. Confirm this is intended.
pub fn recommended_batch_size(
    available_memory_bytes: u64,
    per_request_memory_bytes: u64,
    max_batch: usize,
) -> usize {
    if per_request_memory_bytes == 0 {
        return max_batch;
    }

    let fits = available_memory_bytes / per_request_memory_bytes;
    // Clamp in u64 *before* narrowing so 32-bit targets cannot truncate `fits`.
    // `max_batch as u64` is lossless (usize is never wider than 64 bits), and the
    // clamped value is <= max_batch, so the final cast back to usize is lossless too.
    let capped = fits.min(max_batch as u64) as usize;
    capped.max(1)
}
176
/// Classification of errors for monitoring and alerting.
///
/// Derives `Hash` alongside `Eq`, so classes can key counters in a
/// `HashMap`/`HashSet`. `Display` yields a stable snake_case label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorClass {
    /// Retry may help (timeout, resource busy, transient IO).
    Transient,
    /// Won't recover without user intervention (invalid input, model error).
    Permanent,
    /// Memory/capacity related — may recover if load decreases.
    ResourceExhaustion,
}

impl std::fmt::Display for ErrorClass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Transient => "transient",
            Self::Permanent => "permanent",
            Self::ResourceExhaustion => "resource_exhaustion",
        };
        f.write_str(label)
    }
}
197
198/// Classify an error for monitoring purposes.
199pub fn classify_error(error: &RuntimeError) -> ErrorClass {
200    match error {
201        RuntimeError::Io(_) => ErrorClass::Transient,
202        RuntimeError::Timeout { .. } => ErrorClass::Transient,
203        RuntimeError::Server(_) => ErrorClass::Transient,
204        RuntimeError::CircuitOpen => ErrorClass::Transient,
205        RuntimeError::CapacityExhausted { .. } => ErrorClass::ResourceExhaustion,
206        RuntimeError::Config(_) => ErrorClass::Permanent,
207        RuntimeError::FileNotFound { .. } => ErrorClass::Permanent,
208        RuntimeError::Tokenizer(_) => ErrorClass::Permanent,
209        RuntimeError::GenerationStopped { .. } => ErrorClass::Permanent,
210        RuntimeError::Core(_) => ErrorClass::Permanent,
211        RuntimeError::Kernel(_) => ErrorClass::Permanent,
212        RuntimeError::Model(_) => ErrorClass::Permanent,
213        RuntimeError::BatchError(errors) => {
214            // If any error is resource exhaustion, classify as such
215            for e in errors {
216                if classify_error(e) == ErrorClass::ResourceExhaustion {
217                    return ErrorClass::ResourceExhaustion;
218                }
219            }
220            // Otherwise check for transient
221            for e in errors {
222                if classify_error(e) == ErrorClass::Transient {
223                    return ErrorClass::Transient;
224                }
225            }
226            ErrorClass::Permanent
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn recovery_strategy_io_error() {
        let error = RuntimeError::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset",
        ));
        let strategy = recovery_strategy_for(&error);
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_config_error() {
        let error = RuntimeError::Config("bad config".to_string());
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Abort));
    }

    #[test]
    fn recovery_strategy_timeout() {
        let error = RuntimeError::Timeout {
            operation: "test".to_string(),
            duration_ms: 1000,
        };
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_capacity() {
        let error = RuntimeError::CapacityExhausted {
            resource: "kv_cache".to_string(),
        };
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_circuit_open() {
        let strategy = recovery_strategy_for(&RuntimeError::CircuitOpen);
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_tokenizer() {
        let error = RuntimeError::Tokenizer("bad token".to_string());
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Fallback(_)));
    }

    #[test]
    fn recovery_strategy_display() {
        let retry = RecoveryStrategy::Retry {
            max_attempts: 3,
            delay: Duration::from_millis(100),
        };
        assert!(format!("{}", retry).contains("retry"));

        let fallback = RecoveryStrategy::Fallback("alt".to_string());
        assert!(format!("{}", fallback).contains("fallback"));

        assert_eq!(format!("{}", RecoveryStrategy::Abort), "abort");
    }

    #[test]
    fn retry_succeeds_first_attempt() {
        let mut count = 0;
        let result = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            Ok(42)
        });
        assert_eq!(result.expect("should succeed"), 42);
        assert_eq!(count, 1);
    }

    #[test]
    fn retry_succeeds_second_attempt() {
        let mut count = 0;
        let result = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            if count < 2 {
                Err(RuntimeError::Server("transient".to_string()))
            } else {
                Ok(42)
            }
        });
        assert_eq!(result.expect("should succeed"), 42);
        assert_eq!(count, 2);
    }

    #[test]
    fn retry_exhausts_attempts() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            Err(RuntimeError::Server("fail".to_string()))
        });
        assert!(result.is_err());
        assert_eq!(count, 3);
    }

    #[test]
    fn retry_returns_last_error() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            Err(RuntimeError::Server(format!("fail {}", count)))
        });
        let err = result.expect_err("all attempts fail");
        assert!(matches!(err, RuntimeError::Server(ref msg) if msg == "fail 3"));
    }

    #[test]
    fn retry_zero_attempts_treated_as_one() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(0, Duration::from_millis(1), || {
            count += 1;
            Ok(99)
        });
        assert_eq!(result.expect("should succeed"), 99);
        assert_eq!(count, 1);
    }

    #[test]
    fn with_timeout_success() {
        let result = with_timeout(Duration::from_secs(5), || Ok(42));
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[test]
    fn with_timeout_propagates_inner_error() {
        let result: RuntimeResult<i32> = with_timeout(Duration::from_secs(5), || {
            Err(RuntimeError::Config("inner".to_string()))
        });
        assert!(matches!(result, Err(RuntimeError::Config(_))));
    }

    #[test]
    fn with_timeout_expires() {
        let result: RuntimeResult<i32> = with_timeout(Duration::from_millis(10), || {
            std::thread::sleep(Duration::from_secs(5));
            Ok(42)
        });
        let err = result.expect_err("should timeout");
        assert!(matches!(err, RuntimeError::Timeout { .. }));
    }

    #[test]
    fn batch_size_normal() {
        assert_eq!(recommended_batch_size(1_000_000, 100_000, 16), 10);
    }

    #[test]
    fn batch_size_capped_at_max() {
        assert_eq!(recommended_batch_size(10_000_000, 100_000, 8), 8);
    }

    #[test]
    fn batch_size_minimum_one() {
        assert_eq!(recommended_batch_size(1, 1_000_000, 16), 1);
    }

    #[test]
    fn batch_size_zero_per_request() {
        assert_eq!(recommended_batch_size(1_000_000, 0, 16), 16);
    }

    #[test]
    fn classify_io_error() {
        let error = RuntimeError::Io(std::io::Error::other("test"));
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_config_error() {
        let error = RuntimeError::Config("bad".to_string());
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn classify_capacity_error() {
        let error = RuntimeError::CapacityExhausted {
            resource: "mem".to_string(),
        };
        assert_eq!(classify_error(&error), ErrorClass::ResourceExhaustion);
    }

    #[test]
    fn classify_timeout_error() {
        let error = RuntimeError::Timeout {
            operation: "gen".to_string(),
            duration_ms: 1000,
        };
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_batch_error_resource() {
        let error = RuntimeError::BatchError(vec![RuntimeError::CapacityExhausted {
            resource: "mem".to_string(),
        }]);
        assert_eq!(classify_error(&error), ErrorClass::ResourceExhaustion);
    }

    #[test]
    fn classify_batch_error_transient() {
        let error = RuntimeError::BatchError(vec![RuntimeError::Server("err".to_string())]);
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_batch_error_permanent() {
        let error = RuntimeError::BatchError(vec![RuntimeError::Config("bad".to_string())]);
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn classify_batch_error_empty() {
        let error = RuntimeError::BatchError(Vec::new());
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn error_class_display() {
        assert_eq!(format!("{}", ErrorClass::Transient), "transient");
        assert_eq!(format!("{}", ErrorClass::Permanent), "permanent");
        assert_eq!(
            format!("{}", ErrorClass::ResourceExhaustion),
            "resource_exhaustion"
        );
    }
}