use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::runtime::ProviderKind;
/// Shorthand for a [`Result`] whose error type is [`InferenceError`].
pub type InferenceResult<T> = Result<T, InferenceError>;
#[derive(Debug, Clone, thiserror::Error, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum InferenceError {
#[error("rate-limited (retry after {retry_after:?})")]
RateLimited {
provider: ProviderKind,
#[serde(with = "duration_opt_ms")]
retry_after: Option<Duration>,
},
#[error("circuit open for {provider:?} until {retry_at_unix_ms} (opened at {opened_at_unix_ms})")]
CircuitOpen {
provider: ProviderKind,
opened_at_unix_ms: u64,
retry_at_unix_ms: u64,
},
#[error("content filtered: {reason}")]
ContentFiltered { reason: String },
#[error("context length exceeded ({tokens} > {max_tokens})")]
ContextLengthExceeded { tokens: u32, max_tokens: u32 },
#[error("bad request: {message}")]
BadRequest { message: String },
#[error("unauthorized: {message}")]
Unauthorized { message: String },
#[error("forbidden: {message}")]
Forbidden { message: String },
#[error("backpressure: {0}")]
Backpressure(String),
#[error("budget exceeded for `{deployment}`")]
BudgetExceeded { deployment: String },
#[error("network error: {0}")]
NetworkError(String),
#[error("server error: {status}")]
ServerError { status: u16, body: Option<String> },
#[error("timeout after {elapsed_ms}ms")]
Timeout { elapsed_ms: u64 },
#[error("CUDA context poisoned: {0}")]
CudaContextPoisoned(String),
#[error("internal: {0}")]
Internal(String),
}
impl InferenceError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
InferenceError::RateLimited { .. }
| InferenceError::ServerError { .. }
| InferenceError::Timeout { .. }
| InferenceError::NetworkError(_)
)
}
pub fn counts_as_circuit_failure(&self) -> bool {
matches!(
self,
InferenceError::ServerError { .. }
| InferenceError::Timeout { .. }
| InferenceError::NetworkError(_)
)
}
}
/// Serde `with` adapter that encodes an `Option<Duration>` as an optional
/// count of whole milliseconds (`u64`).
///
/// Sub-millisecond precision is dropped on serialization because
/// `Duration::as_millis` returns whole milliseconds.
mod duration_opt_ms {
    // Needed for `u64::try_from` on editions before 2021; harmless afterwards.
    use std::convert::TryFrom;
    use std::time::Duration;

    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serializes `d` as `Option<u64>` milliseconds.
    ///
    /// A duration too large for `u64` milliseconds saturates to `u64::MAX`
    /// instead of wrapping.
    pub fn serialize<S>(d: &Option<Duration>, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `as_millis` yields a `u128`; a plain `as u64` cast would silently truncate.
        let millis: Option<u64> = d.map(|x| u64::try_from(x.as_millis()).unwrap_or(u64::MAX));
        millis.serialize(s)
    }

    /// Deserializes an `Option<u64>` millisecond count into an `Option<Duration>`.
    pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = Option::<u64>::deserialize(d)?;
        Ok(millis.map(Duration::from_millis))
    }
}