atomr_infer_core/
error.rs1use std::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::runtime::ProviderKind;
10
11pub type InferenceResult<T> = Result<T, InferenceError>;
12
13#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
14#[serde(tag = "kind", rename_all = "snake_case")]
15#[non_exhaustive]
16pub enum InferenceError {
17 #[error("rate-limited (retry after {retry_after:?})")]
20 RateLimited {
21 provider: ProviderKind,
22 #[serde(with = "duration_opt_ms")]
23 retry_after: Option<Duration>,
24 },
25
26 #[error("circuit open for {provider:?} until {retry_at_unix_ms} (opened at {opened_at_unix_ms})")]
28 CircuitOpen {
29 provider: ProviderKind,
30 opened_at_unix_ms: u64,
31 retry_at_unix_ms: u64,
32 },
33
34 #[error("content filtered: {reason}")]
36 ContentFiltered { reason: String },
37
38 #[error("context length exceeded ({tokens} > {max_tokens})")]
40 ContextLengthExceeded { tokens: u32, max_tokens: u32 },
41
42 #[error("bad request: {message}")]
44 BadRequest { message: String },
45
46 #[error("unauthorized: {message}")]
48 Unauthorized { message: String },
49
50 #[error("forbidden: {message}")]
52 Forbidden { message: String },
53
54 #[error("backpressure: {0}")]
56 Backpressure(String),
57
58 #[error("budget exceeded for `{deployment}`")]
60 BudgetExceeded { deployment: String },
61
62 #[error("network error: {0}")]
64 NetworkError(String),
65
66 #[error("server error: {status}")]
68 ServerError { status: u16, body: Option<String> },
69
70 #[error("timeout after {elapsed_ms}ms")]
72 Timeout { elapsed_ms: u64 },
73
74 #[error("CUDA context poisoned: {0}")]
77 CudaContextPoisoned(String),
78
79 #[error("internal: {0}")]
81 Internal(String),
82}
83
84impl InferenceError {
85 pub fn is_retryable(&self) -> bool {
86 matches!(
87 self,
88 InferenceError::RateLimited { .. }
89 | InferenceError::ServerError { .. }
90 | InferenceError::Timeout { .. }
91 | InferenceError::NetworkError(_)
92 )
93 }
94
95 pub fn counts_as_circuit_failure(&self) -> bool {
98 matches!(
99 self,
100 InferenceError::ServerError { .. }
101 | InferenceError::Timeout { .. }
102 | InferenceError::NetworkError(_)
103 )
104 }
105}
106
107mod duration_opt_ms {
108 use std::time::Duration;
109
110 use serde::{Deserialize, Deserializer, Serialize, Serializer};
111
112 pub fn serialize<S>(d: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
113 where
114 S: Serializer,
115 {
116 d.map(|x| x.as_millis() as u64).serialize(s)
117 }
118
119 pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
120 where
121 D: Deserializer<'de>,
122 {
123 Option::<u64>::deserialize(d).map(|o| o.map(Duration::from_millis))
124 }
125}