1use std::time::Duration;
7
8use crate::error::{RuntimeError, RuntimeResult};
9
/// How a caller should react to a [`RuntimeError`], as chosen by
/// [`recovery_strategy_for`].
///
/// Equality is derived so strategies can be compared directly
/// (e.g. `assert_eq!(strategy, RecoveryStrategy::Abort)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Re-run the failed operation.
    Retry {
        /// Upper bound on the number of attempts.
        max_attempts: usize,
        /// Delay before retrying; rendered as the "base delay" by `Display`.
        delay: Duration,
    },
    /// Degrade gracefully; the string describes the fallback action
    /// (e.g. `"use raw token IDs"`).
    Fallback(String),
    /// Give up without retrying.
    Abort,
}
25
26impl std::fmt::Display for RecoveryStrategy {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 Self::Retry {
30 max_attempts,
31 delay,
32 } => write!(
33 f,
34 "retry (max {} attempts, base delay {:?})",
35 max_attempts, delay
36 ),
37 Self::Fallback(desc) => write!(f, "fallback: {}", desc),
38 Self::Abort => write!(f, "abort"),
39 }
40 }
41}
42
43pub fn recovery_strategy_for(error: &RuntimeError) -> RecoveryStrategy {
45 match error {
46 RuntimeError::Io(_) => RecoveryStrategy::Retry {
48 max_attempts: 3,
49 delay: Duration::from_millis(100),
50 },
51 RuntimeError::Timeout { .. } => RecoveryStrategy::Retry {
53 max_attempts: 2,
54 delay: Duration::from_millis(500),
55 },
56 RuntimeError::CapacityExhausted { .. } => RecoveryStrategy::Retry {
58 max_attempts: 3,
59 delay: Duration::from_millis(200),
60 },
61 RuntimeError::CircuitOpen => RecoveryStrategy::Retry {
63 max_attempts: 1,
64 delay: Duration::from_secs(5),
65 },
66 RuntimeError::Config(_) => RecoveryStrategy::Abort,
68 RuntimeError::FileNotFound { .. } => RecoveryStrategy::Abort,
70 RuntimeError::Tokenizer(_) => RecoveryStrategy::Fallback("use raw token IDs".to_string()),
72 RuntimeError::GenerationStopped { .. } => RecoveryStrategy::Abort,
74 RuntimeError::Server(_) => RecoveryStrategy::Retry {
76 max_attempts: 2,
77 delay: Duration::from_millis(200),
78 },
79 RuntimeError::Core(_) => RecoveryStrategy::Abort,
81 RuntimeError::Kernel(_) => RecoveryStrategy::Abort,
82 RuntimeError::Model(_) => RecoveryStrategy::Abort,
83 RuntimeError::BatchError(_) => RecoveryStrategy::Retry {
85 max_attempts: 1,
86 delay: Duration::from_millis(100),
87 },
88 }
89}
90
91pub fn retry_with_backoff<F, T>(
98 max_attempts: usize,
99 base_delay: Duration,
100 mut f: F,
101) -> RuntimeResult<T>
102where
103 F: FnMut() -> RuntimeResult<T>,
104{
105 let attempts = max_attempts.max(1);
106 let mut last_error = None;
107 let mut delay = base_delay;
108
109 for attempt in 0..attempts {
110 match f() {
111 Ok(val) => return Ok(val),
112 Err(e) => {
113 tracing::debug!(
114 attempt = attempt + 1,
115 max_attempts = attempts,
116 error = %e,
117 "retry attempt failed"
118 );
119 last_error = Some(e);
120
121 if attempt + 1 < attempts {
122 std::thread::sleep(delay);
123 delay = delay.saturating_mul(2);
124 }
125 }
126 }
127 }
128
129 Err(last_error.unwrap_or_else(|| {
130 RuntimeError::Config("retry_with_backoff called with zero attempts".to_string())
131 }))
132}
133
134pub fn with_timeout<F, T>(duration: Duration, f: F) -> RuntimeResult<T>
139where
140 F: FnOnce() -> RuntimeResult<T> + Send + 'static,
141 T: Send + 'static,
142{
143 let (tx, rx) = std::sync::mpsc::channel();
144 std::thread::spawn(move || {
145 let result = f();
146 let _ = tx.send(result);
147 });
148
149 rx.recv_timeout(duration).unwrap_or_else(|e| match e {
150 std::sync::mpsc::RecvTimeoutError::Timeout => Err(RuntimeError::Timeout {
151 operation: "with_timeout".to_string(),
152 duration_ms: duration.as_millis() as u64,
153 }),
154 std::sync::mpsc::RecvTimeoutError::Disconnected => Err(RuntimeError::Server(
155 "timeout worker thread panicked".to_string(),
156 )),
157 })
158}
159
/// Recommends how many requests to batch together under a memory budget.
///
/// Computes `available_memory_bytes / per_request_memory_bytes` and clamps it to
/// at most `max_batch` and at least 1, so a caller can always make progress with a
/// single request even when one does not nominally fit.
///
/// Special cases:
/// - `per_request_memory_bytes == 0`: the memory limit is ignored and `max_batch`
///   is returned unchanged, which may be 0.
/// - Otherwise the lower bound of 1 wins over `max_batch`, so `max_batch == 0`
///   yields 1.
pub fn recommended_batch_size(
    available_memory_bytes: u64,
    per_request_memory_bytes: u64,
    max_batch: usize,
) -> usize {
    if per_request_memory_bytes == 0 {
        return max_batch;
    }

    // On 32-bit targets the quotient can exceed `usize::MAX`; an `as` cast would
    // silently truncate it (possibly to 0, collapsing the batch to 1). Saturate
    // instead — anything above `max_batch` is clamped below anyway.
    let fits = usize::try_from(available_memory_bytes / per_request_memory_bytes)
        .unwrap_or(usize::MAX);
    fits.min(max_batch).max(1)
}
176
/// Coarse category of a `RuntimeError`, as produced by `classify_error`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorClass {
    /// Expected to be temporary: I/O, timeouts, server errors and an open circuit
    /// are mapped here.
    Transient,
    /// Not expected to clear up by retrying: configuration, missing files,
    /// tokenizer, core/kernel/model errors and the like.
    Permanent,
    /// A capacity limit was hit (`CapacityExhausted`).
    ResourceExhaustion,
}
187
188impl std::fmt::Display for ErrorClass {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 match self {
191 Self::Transient => write!(f, "transient"),
192 Self::Permanent => write!(f, "permanent"),
193 Self::ResourceExhaustion => write!(f, "resource_exhaustion"),
194 }
195 }
196}
197
198pub fn classify_error(error: &RuntimeError) -> ErrorClass {
200 match error {
201 RuntimeError::Io(_) => ErrorClass::Transient,
202 RuntimeError::Timeout { .. } => ErrorClass::Transient,
203 RuntimeError::Server(_) => ErrorClass::Transient,
204 RuntimeError::CircuitOpen => ErrorClass::Transient,
205 RuntimeError::CapacityExhausted { .. } => ErrorClass::ResourceExhaustion,
206 RuntimeError::Config(_) => ErrorClass::Permanent,
207 RuntimeError::FileNotFound { .. } => ErrorClass::Permanent,
208 RuntimeError::Tokenizer(_) => ErrorClass::Permanent,
209 RuntimeError::GenerationStopped { .. } => ErrorClass::Permanent,
210 RuntimeError::Core(_) => ErrorClass::Permanent,
211 RuntimeError::Kernel(_) => ErrorClass::Permanent,
212 RuntimeError::Model(_) => ErrorClass::Permanent,
213 RuntimeError::BatchError(errors) => {
214 for e in errors {
216 if classify_error(e) == ErrorClass::ResourceExhaustion {
217 return ErrorClass::ResourceExhaustion;
218 }
219 }
220 for e in errors {
222 if classify_error(e) == ErrorClass::Transient {
223 return ErrorClass::Transient;
224 }
225 }
226 ErrorClass::Permanent
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn recovery_strategy_io_error() {
        let error = RuntimeError::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset",
        ));
        let strategy = recovery_strategy_for(&error);
        // `matches!` only yields a bool; it must be asserted or the test can never fail.
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_config_error() {
        let error = RuntimeError::Config("bad config".to_string());
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Abort));
    }

    #[test]
    fn recovery_strategy_timeout() {
        let error = RuntimeError::Timeout {
            operation: "test".to_string(),
            duration_ms: 1000,
        };
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_capacity() {
        let error = RuntimeError::CapacityExhausted {
            resource: "kv_cache".to_string(),
        };
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Retry { .. }));
    }

    #[test]
    fn recovery_strategy_circuit_open() {
        let strategy = recovery_strategy_for(&RuntimeError::CircuitOpen);
        assert!(matches!(
            strategy,
            RecoveryStrategy::Retry {
                max_attempts: 1,
                ..
            }
        ));
    }

    #[test]
    fn recovery_strategy_tokenizer() {
        let error = RuntimeError::Tokenizer("bad token".to_string());
        let strategy = recovery_strategy_for(&error);
        assert!(matches!(strategy, RecoveryStrategy::Fallback(_)));
    }

    #[test]
    fn recovery_strategy_display() {
        let retry = RecoveryStrategy::Retry {
            max_attempts: 3,
            delay: Duration::from_millis(100),
        };
        assert!(format!("{}", retry).contains("retry"));

        let fallback = RecoveryStrategy::Fallback("alt".to_string());
        assert!(format!("{}", fallback).contains("fallback"));

        assert_eq!(format!("{}", RecoveryStrategy::Abort), "abort");
    }

    #[test]
    fn retry_succeeds_first_attempt() {
        let mut count = 0;
        let result = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            Ok(42)
        });
        assert_eq!(result.expect("should succeed"), 42);
        assert_eq!(count, 1);
    }

    #[test]
    fn retry_succeeds_second_attempt() {
        let mut count = 0;
        let result = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            if count < 2 {
                Err(RuntimeError::Server("transient".to_string()))
            } else {
                Ok(42)
            }
        });
        assert_eq!(result.expect("should succeed"), 42);
        assert_eq!(count, 2);
    }

    #[test]
    fn retry_exhausts_attempts() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(3, Duration::from_millis(1), || {
            count += 1;
            Err(RuntimeError::Server("fail".to_string()))
        });
        assert!(result.is_err());
        assert_eq!(count, 3);
    }

    #[test]
    fn retry_zero_attempts_treated_as_one() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(0, Duration::from_millis(1), || {
            count += 1;
            Ok(99)
        });
        assert_eq!(result.expect("should succeed"), 99);
        assert_eq!(count, 1);
    }

    #[test]
    fn retry_zero_attempts_failure_returns_last_error() {
        let mut count = 0;
        let result: RuntimeResult<i32> = retry_with_backoff(0, Duration::from_millis(1), || {
            count += 1;
            Err(RuntimeError::Server("fail".to_string()))
        });
        assert!(matches!(result, Err(RuntimeError::Server(_))));
        assert_eq!(count, 1);
    }

    #[test]
    fn with_timeout_success() {
        let result = with_timeout(Duration::from_secs(5), || Ok(42));
        assert_eq!(result.expect("should succeed"), 42);
    }

    #[test]
    fn with_timeout_expires() {
        // Block the worker on a channel instead of sleeping; dropping `_hold` at the
        // end of the test releases the detached thread.
        let (_hold, wait) = std::sync::mpsc::channel::<()>();
        let result: RuntimeResult<i32> = with_timeout(Duration::from_millis(10), move || {
            let _ = wait.recv();
            Ok(42)
        });
        assert!(result.is_err());
        let err = result.expect_err("should timeout");
        assert!(err.to_string().contains("timeout") || err.to_string().contains("Timeout"));
    }

    #[test]
    fn batch_size_normal() {
        assert_eq!(recommended_batch_size(1_000_000, 100_000, 16), 10);
    }

    #[test]
    fn batch_size_capped_at_max() {
        assert_eq!(recommended_batch_size(10_000_000, 100_000, 8), 8);
    }

    #[test]
    fn batch_size_minimum_one() {
        assert_eq!(recommended_batch_size(1, 1_000_000, 16), 1);
    }

    #[test]
    fn batch_size_zero_per_request() {
        assert_eq!(recommended_batch_size(1_000_000, 0, 16), 16);
    }

    #[test]
    fn classify_io_error() {
        let error = RuntimeError::Io(std::io::Error::other("test"));
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_config_error() {
        let error = RuntimeError::Config("bad".to_string());
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn classify_capacity_error() {
        let error = RuntimeError::CapacityExhausted {
            resource: "mem".to_string(),
        };
        assert_eq!(classify_error(&error), ErrorClass::ResourceExhaustion);
    }

    #[test]
    fn classify_timeout_error() {
        let error = RuntimeError::Timeout {
            operation: "gen".to_string(),
            duration_ms: 1000,
        };
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_batch_error_resource() {
        let error = RuntimeError::BatchError(vec![RuntimeError::CapacityExhausted {
            resource: "mem".to_string(),
        }]);
        assert_eq!(classify_error(&error), ErrorClass::ResourceExhaustion);
    }

    #[test]
    fn classify_batch_error_transient() {
        let error = RuntimeError::BatchError(vec![RuntimeError::Server("err".to_string())]);
        assert_eq!(classify_error(&error), ErrorClass::Transient);
    }

    #[test]
    fn classify_batch_error_permanent() {
        let error = RuntimeError::BatchError(vec![RuntimeError::Config("bad".to_string())]);
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn classify_batch_error_mixed_prefers_resource_exhaustion() {
        let error = RuntimeError::BatchError(vec![
            RuntimeError::Config("bad".to_string()),
            RuntimeError::Server("err".to_string()),
            RuntimeError::CapacityExhausted {
                resource: "mem".to_string(),
            },
        ]);
        assert_eq!(classify_error(&error), ErrorClass::ResourceExhaustion);
    }

    #[test]
    fn classify_empty_batch_is_permanent() {
        let error = RuntimeError::BatchError(Vec::new());
        assert_eq!(classify_error(&error), ErrorClass::Permanent);
    }

    #[test]
    fn error_class_display() {
        assert_eq!(format!("{}", ErrorClass::Transient), "transient");
        assert_eq!(format!("{}", ErrorClass::Permanent), "permanent");
        assert_eq!(
            format!("{}", ErrorClass::ResourceExhaustion),
            "resource_exhaustion"
        );
    }
}