Skip to main content

inference_core/
error.rs

1//! `InferenceError` — the typed error surface that flows up to the
2//! `RequestActor` regardless of whether the bottleneck was GPU memory,
3//! GIL contention, or remote provider quota (doc §6.2).
4
5use std::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::runtime::ProviderKind;
10
/// Convenience alias for a `Result` whose error type is [`InferenceError`].
pub type InferenceResult<T> = Result<T, InferenceError>;
12
13#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
14#[serde(tag = "kind", rename_all = "snake_case")]
15pub enum InferenceError {
16    /// 429 from a remote provider. Worker backs off and retries unless
17    /// `max_retries` is exhausted; then this surfaces to the request.
18    #[error("rate-limited (retry after {retry_after:?})")]
19    RateLimited {
20        provider: ProviderKind,
21        #[serde(with = "duration_opt_ms")]
22        retry_after: Option<Duration>,
23    },
24
25    /// Circuit breaker is open for `(provider, endpoint)`. Fail-fast.
26    #[error("circuit open for {provider:?} until {retry_at_unix_ms} (opened at {opened_at_unix_ms})")]
27    CircuitOpen {
28        provider: ProviderKind,
29        opened_at_unix_ms: u64,
30        retry_at_unix_ms: u64,
31    },
32
33    /// Provider safety filter rejected the input/output. Not retryable.
34    #[error("content filtered: {reason}")]
35    ContentFiltered { reason: String },
36
37    /// Input exceeded the model's context window. Not retryable.
38    #[error("context length exceeded ({tokens} > {max_tokens})")]
39    ContextLengthExceeded { tokens: u32, max_tokens: u32 },
40
41    /// 400 from the provider — caller-side bug.
42    #[error("bad request: {message}")]
43    BadRequest { message: String },
44
45    /// 401 — triggers `RemoteSessionActor::rebuild`.
46    #[error("unauthorized: {message}")]
47    Unauthorized { message: String },
48
49    /// 403 — model/feature access denied.
50    #[error("forbidden: {message}")]
51    Forbidden { message: String },
52
53    /// Mailbox / engine queue full. Upstream decides fallback / 429.
54    #[error("backpressure: {0}")]
55    Backpressure(String),
56
57    /// Spend ceiling reached (doc §12.4).
58    #[error("budget exceeded for `{deployment}`")]
59    BudgetExceeded { deployment: String },
60
61    /// Network blip below the HTTP layer.
62    #[error("network error: {0}")]
63    NetworkError(String),
64
65    /// 5xx from provider. Counts toward circuit breaker.
66    #[error("server error: {status}")]
67    ServerError { status: u16, body: Option<String> },
68
69    /// Request or read timeout.
70    #[error("timeout after {elapsed_ms}ms")]
71    Timeout { elapsed_ms: u64 },
72
73    /// Local CUDA context poisoned (sticky failure). Triggers two-tier
74    /// rebuild on the local `WorkerActor` → `ContextActor` boundary.
75    #[error("CUDA context poisoned: {0}")]
76    CudaContextPoisoned(String),
77
78    /// Catch-all for runtime-internal bugs. Not retryable.
79    #[error("internal: {0}")]
80    Internal(String),
81}
82
83impl InferenceError {
84    pub fn is_retryable(&self) -> bool {
85        matches!(
86            self,
87            InferenceError::RateLimited { .. }
88                | InferenceError::ServerError { .. }
89                | InferenceError::Timeout { .. }
90                | InferenceError::NetworkError(_)
91        )
92    }
93
94    /// Whether this error counts toward the circuit-breaker failure
95    /// budget. 429s and content-filter refusals do not (doc §12.2).
96    pub fn counts_as_circuit_failure(&self) -> bool {
97        matches!(
98            self,
99            InferenceError::ServerError { .. }
100                | InferenceError::Timeout { .. }
101                | InferenceError::NetworkError(_)
102        )
103    }
104}
105
106mod duration_opt_ms {
107    use std::time::Duration;
108
109    use serde::{Deserialize, Deserializer, Serialize, Serializer};
110
111    pub fn serialize<S>(d: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
112    where
113        S: Serializer,
114    {
115        d.map(|x| x.as_millis() as u64).serialize(s)
116    }
117
118    pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
119    where
120        D: Deserializer<'de>,
121    {
122        Option::<u64>::deserialize(d).map(|o| o.map(Duration::from_millis))
123    }
124}