Skip to main content

atomr_infer_core/
error.rs

//! `InferenceError` — the typed error surface that flows up to the
//! `RequestActor` regardless of whether the bottleneck was GPU memory,
//! GIL contention, or remote provider quota (doc §6.2).
4
5use std::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::runtime::ProviderKind;
10
11pub type InferenceResult<T> = Result<T, InferenceError>;
12
13#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
14#[serde(tag = "kind", rename_all = "snake_case")]
15#[non_exhaustive]
16pub enum InferenceError {
17    /// 429 from a remote provider. Worker backs off and retries unless
18    /// `max_retries` is exhausted; then this surfaces to the request.
19    #[error("rate-limited (retry after {retry_after:?})")]
20    RateLimited {
21        provider: ProviderKind,
22        #[serde(with = "duration_opt_ms")]
23        retry_after: Option<Duration>,
24    },
25
26    /// Circuit breaker is open for `(provider, endpoint)`. Fail-fast.
27    #[error("circuit open for {provider:?} until {retry_at_unix_ms} (opened at {opened_at_unix_ms})")]
28    CircuitOpen {
29        provider: ProviderKind,
30        opened_at_unix_ms: u64,
31        retry_at_unix_ms: u64,
32    },
33
34    /// Provider safety filter rejected the input/output. Not retryable.
35    #[error("content filtered: {reason}")]
36    ContentFiltered { reason: String },
37
38    /// Input exceeded the model's context window. Not retryable.
39    #[error("context length exceeded ({tokens} > {max_tokens})")]
40    ContextLengthExceeded { tokens: u32, max_tokens: u32 },
41
42    /// 400 from the provider — caller-side bug.
43    #[error("bad request: {message}")]
44    BadRequest { message: String },
45
46    /// 401 — triggers `RemoteSessionActor::rebuild`.
47    #[error("unauthorized: {message}")]
48    Unauthorized { message: String },
49
50    /// 403 — model/feature access denied.
51    #[error("forbidden: {message}")]
52    Forbidden { message: String },
53
54    /// Mailbox / engine queue full. Upstream decides fallback / 429.
55    #[error("backpressure: {0}")]
56    Backpressure(String),
57
58    /// Spend ceiling reached (doc §12.4).
59    #[error("budget exceeded for `{deployment}`")]
60    BudgetExceeded { deployment: String },
61
62    /// Network blip below the HTTP layer.
63    #[error("network error: {0}")]
64    NetworkError(String),
65
66    /// 5xx from provider. Counts toward circuit breaker.
67    #[error("server error: {status}")]
68    ServerError { status: u16, body: Option<String> },
69
70    /// Request or read timeout.
71    #[error("timeout after {elapsed_ms}ms")]
72    Timeout { elapsed_ms: u64 },
73
74    /// Local CUDA context poisoned (sticky failure). Triggers two-tier
75    /// rebuild on the local `WorkerActor` → `ContextActor` boundary.
76    #[error("CUDA context poisoned: {0}")]
77    CudaContextPoisoned(String),
78
79    /// Catch-all for runtime-internal bugs. Not retryable.
80    #[error("internal: {0}")]
81    Internal(String),
82}
83
84impl InferenceError {
85    pub fn is_retryable(&self) -> bool {
86        matches!(
87            self,
88            InferenceError::RateLimited { .. }
89                | InferenceError::ServerError { .. }
90                | InferenceError::Timeout { .. }
91                | InferenceError::NetworkError(_)
92        )
93    }
94
95    /// Whether this error counts toward the circuit-breaker failure
96    /// budget. 429s and content-filter refusals do not (doc §12.2).
97    pub fn counts_as_circuit_failure(&self) -> bool {
98        matches!(
99            self,
100            InferenceError::ServerError { .. }
101                | InferenceError::Timeout { .. }
102                | InferenceError::NetworkError(_)
103        )
104    }
105}
106
107mod duration_opt_ms {
108    use std::time::Duration;
109
110    use serde::{Deserialize, Deserializer, Serialize, Serializer};
111
112    pub fn serialize<S>(d: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
113    where
114        S: Serializer,
115    {
116        d.map(|x| x.as_millis() as u64).serialize(s)
117    }
118
119    pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
120    where
121        D: Deserializer<'de>,
122    {
123        Option::<u64>::deserialize(d).map(|o| o.map(Duration::from_millis))
124    }
125}