1use std::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::runtime::ProviderKind;
10
11pub type InferenceResult<T> = Result<T, InferenceError>;
12
13#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
14#[serde(tag = "kind", rename_all = "snake_case")]
15pub enum InferenceError {
16 #[error("rate-limited (retry after {retry_after:?})")]
19 RateLimited {
20 provider: ProviderKind,
21 #[serde(with = "duration_opt_ms")]
22 retry_after: Option<Duration>,
23 },
24
25 #[error("circuit open for {provider:?} until {retry_at_unix_ms} (opened at {opened_at_unix_ms})")]
27 CircuitOpen {
28 provider: ProviderKind,
29 opened_at_unix_ms: u64,
30 retry_at_unix_ms: u64,
31 },
32
33 #[error("content filtered: {reason}")]
35 ContentFiltered { reason: String },
36
37 #[error("context length exceeded ({tokens} > {max_tokens})")]
39 ContextLengthExceeded { tokens: u32, max_tokens: u32 },
40
41 #[error("bad request: {message}")]
43 BadRequest { message: String },
44
45 #[error("unauthorized: {message}")]
47 Unauthorized { message: String },
48
49 #[error("forbidden: {message}")]
51 Forbidden { message: String },
52
53 #[error("backpressure: {0}")]
55 Backpressure(String),
56
57 #[error("budget exceeded for `{deployment}`")]
59 BudgetExceeded { deployment: String },
60
61 #[error("network error: {0}")]
63 NetworkError(String),
64
65 #[error("server error: {status}")]
67 ServerError { status: u16, body: Option<String> },
68
69 #[error("timeout after {elapsed_ms}ms")]
71 Timeout { elapsed_ms: u64 },
72
73 #[error("CUDA context poisoned: {0}")]
76 CudaContextPoisoned(String),
77
78 #[error("internal: {0}")]
80 Internal(String),
81}
82
83impl InferenceError {
84 pub fn is_retryable(&self) -> bool {
85 matches!(
86 self,
87 InferenceError::RateLimited { .. }
88 | InferenceError::ServerError { .. }
89 | InferenceError::Timeout { .. }
90 | InferenceError::NetworkError(_)
91 )
92 }
93
94 pub fn counts_as_circuit_failure(&self) -> bool {
97 matches!(
98 self,
99 InferenceError::ServerError { .. }
100 | InferenceError::Timeout { .. }
101 | InferenceError::NetworkError(_)
102 )
103 }
104}
105
106mod duration_opt_ms {
107 use std::time::Duration;
108
109 use serde::{Deserialize, Deserializer, Serialize, Serializer};
110
111 pub fn serialize<S>(d: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
112 where
113 S: Serializer,
114 {
115 d.map(|x| x.as_millis() as u64).serialize(s)
116 }
117
118 pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
119 where
120 D: Deserializer<'de>,
121 {
122 Option::<u64>::deserialize(d).map(|o| o.map(Duration::from_millis))
123 }
124}