use std::error::Error;
use reqwest::StatusCode;
use serde::Serialize;
use serde_json::{json, Value};
use thiserror::Error;
use tokio::time::Duration;
use super::{
http_trace::ProviderHttpTraceRequest, ProviderTransportDiagnostics, ReqwestTransportDiagnostics,
};
/// Retries allowed per provider after the first attempt; `provider_max_attempts()`
/// returns this value plus one for the initial attempt.
pub(crate) const PROVIDER_MAX_RETRIES: usize = 2;
/// Backoff unit in milliseconds; `provider_retry_backoff` multiplies it by the
/// attempt number it is given (linear backoff).
const PROVIDER_RETRY_BASE_BACKOFF_MS: u64 = 200;
/// Coarse category of a failed provider call, serialized in snake_case.
///
/// Whether a given failure is retried is carried separately by
/// `ProviderFailureClassification::disposition`; `provider_retry_policy_json`
/// advertises `Timeout`, `Connection`, `RateLimited` and `ServerError` as the
/// retryable kinds and the rest as fail-fast.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub(crate) enum ProviderFailureKind {
    /// A reqwest timeout, or an explicit timeout built by `timeout_transport_error_with_trace`.
    Timeout,
    /// Connect failure, or a send/body-read interruption recognised from the error source chain.
    Connection,
    /// HTTP 429 Too Many Requests.
    RateLimited,
    /// HTTP 5xx.
    ServerError,
    /// HTTP 401 Unauthorized or 403 Forbidden.
    AuthError,
    /// Any other HTTP 4xx status not covered by a more specific kind.
    ContractError,
    /// A response arrived but could not be interpreted (see `invalid_response_error`).
    InvalidResponse,
    /// Not constructed in this file; presumably used where a requested transport is
    /// unsupported — NOTE(review): confirm at the call sites.
    UnsupportedTransport,
    /// Fallback for unclassified failures, including errors that are not a
    /// `ProviderTransportError` at all.
    Unknown,
}
/// Whether a classified failure should be retried; serialized in snake_case.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub(crate) enum RetryDisposition {
    /// Treated as transient; `format_provider_failure` reports these as
    /// "retries_exhausted" once the caller gives up.
    Retryable,
    /// Not worth retrying; `format_provider_failure` reports these as "fail_fast".
    FailFast,
}
/// The failure category paired with its retry decision.
#[derive(Debug, Clone, Copy, Serialize)]
pub(crate) struct ProviderFailureClassification {
    /// What went wrong.
    pub kind: ProviderFailureKind,
    /// Whether the failure is eligible for retry.
    pub disposition: RetryDisposition,
}
/// Error type wrapped inside `anyhow::Error` by `provider_transport_error`, so that
/// callers can recover the classification via `downcast_ref`.
#[derive(Debug, Error)]
#[error("{message}")]
pub(crate) struct ProviderTransportError {
    /// Failure category and retry decision.
    pub classification: ProviderFailureClassification,
    /// HTTP status code, when the failure carries one.
    pub status: Option<u16>,
    /// Structured context for logs; `None` for errors built without it
    /// (e.g. by `invalid_response_error`).
    pub diagnostics: Option<ProviderTransportDiagnostics>,
    // Human-readable message; it is the only thing the `Display` impl renders.
    message: String,
}
impl ProviderFailureKind {
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Timeout => "timeout",
Self::Connection => "connection",
Self::RateLimited => "rate_limited",
Self::ServerError => "server_error",
Self::AuthError => "auth_error",
Self::ContractError => "contract_error",
Self::InvalidResponse => "invalid_response",
Self::UnsupportedTransport => "unsupported_transport",
Self::Unknown => "unknown",
}
}
}
impl RetryDisposition {
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Retryable => "retryable",
Self::FailFast => "fail_fast",
}
}
}
pub(crate) fn provider_retry_policy_json() -> Value {
json!({
"max_retries_per_provider": PROVIDER_MAX_RETRIES,
"max_attempts_per_provider": provider_max_attempts(),
"base_backoff_ms": PROVIDER_RETRY_BASE_BACKOFF_MS,
"retryable_failure_kinds": [
ProviderFailureKind::Timeout.as_str(),
ProviderFailureKind::Connection.as_str(),
ProviderFailureKind::RateLimited.as_str(),
ProviderFailureKind::ServerError.as_str(),
],
"fail_fast_failure_kinds": [
ProviderFailureKind::AuthError.as_str(),
ProviderFailureKind::ContractError.as_str(),
ProviderFailureKind::InvalidResponse.as_str(),
ProviderFailureKind::UnsupportedTransport.as_str(),
ProviderFailureKind::Unknown.as_str(),
],
})
}
/// Total attempts allowed per provider: the initial attempt plus
/// `PROVIDER_MAX_RETRIES` retries.
pub(crate) fn provider_max_attempts() -> usize {
    let initial_attempt = 1;
    initial_attempt + PROVIDER_MAX_RETRIES
}
/// Linear backoff delay for the given `attempt` number:
/// `attempt * PROVIDER_RETRY_BASE_BACKOFF_MS` milliseconds.
///
/// The product saturates at `u64::MAX` ms instead of overflowing. Without this, an absurd
/// attempt count would panic in debug builds or wrap to a short delay in release builds.
pub(crate) fn provider_retry_backoff(attempt: usize) -> Duration {
    // Lossless on every supported target; clamps instead of truncating if usize were wider.
    let attempt = u64::try_from(attempt).unwrap_or(u64::MAX);
    Duration::from_millis(PROVIDER_RETRY_BASE_BACKOFF_MS.saturating_mul(attempt))
}
pub(crate) fn classify_provider_error(error: &anyhow::Error) -> ProviderFailureClassification {
error
.downcast_ref::<ProviderTransportError>()
.map(|error| error.classification)
.unwrap_or(ProviderFailureClassification {
kind: ProviderFailureKind::Unknown,
disposition: RetryDisposition::FailFast,
})
}
pub(crate) fn provider_transport_error(
classification: ProviderFailureClassification,
status: Option<u16>,
diagnostics: Option<ProviderTransportDiagnostics>,
message: impl Into<String>,
) -> anyhow::Error {
ProviderTransportError {
classification,
status,
diagnostics,
message: message.into(),
}
.into()
}
pub(crate) fn classify_reqwest_transport_error_with_trace(
context: &str,
stage: &str,
provider: &str,
model_ref: Option<&str>,
url: Option<&str>,
error: reqwest::Error,
trace: Option<&ProviderHttpTraceRequest>,
) -> anyhow::Error {
let status = error.status().map(|status| status.as_u16());
let source_chain = error_chain_messages(&error);
let classification = classify_reqwest_transport_failure(stage, &error, &source_chain);
provider_transport_error(
classification,
status,
Some(reqwest_transport_diagnostics(
stage,
provider,
model_ref,
url,
&error,
source_chain,
trace,
)),
format!("{context}: {error}"),
)
}
/// Maps a reqwest error to a failure kind and retry decision.
///
/// The mapping, in priority order:
/// - reqwest timeouts are `Timeout`;
/// - connect errors and recognised send/body interruptions are `Connection`;
/// both of these are retryable.
/// - everything else is `Unknown` and fails fast.
fn classify_reqwest_transport_failure(
    stage: &str,
    error: &reqwest::Error,
    source_chain: &[String],
) -> ProviderFailureClassification {
    let (kind, disposition) = if error.is_timeout() {
        (ProviderFailureKind::Timeout, RetryDisposition::Retryable)
    } else if error.is_connect()
        || is_retryable_request_send_transport_failure(stage, source_chain)
        || is_retryable_response_body_read_interruption(stage, error, source_chain)
    {
        (ProviderFailureKind::Connection, RetryDisposition::Retryable)
    } else {
        (ProviderFailureKind::Unknown, RetryDisposition::FailFast)
    };
    ProviderFailureClassification { kind, disposition }
}
/// Returns true when a `streaming_request_send` failure's source chain mentions a
/// connection-level interruption. Matching is ASCII case-insensitive.
fn is_retryable_request_send_transport_failure(stage: &str, source_chain: &[String]) -> bool {
    const SEND_INTERRUPTION_MARKERS: [&str; 6] = [
        "connection error",
        "connection closed",
        "connection reset",
        "connection aborted",
        "tls close_notify",
        "broken pipe",
    ];

    if stage != "streaming_request_send" {
        return false;
    }
    source_chain.iter().any(|entry| {
        let lowered = entry.to_ascii_lowercase();
        SEND_INTERRUPTION_MARKERS
            .iter()
            .any(|&marker| lowered.contains(marker))
    })
}
/// Returns true when all three of the following hold:
/// - the stage is a response-body read;
/// - reqwest flags the error as a body or decode error;
/// - the source chain mentions a connection drop or truncated body (ASCII case-insensitive).
fn is_retryable_response_body_read_interruption(
    stage: &str,
    error: &reqwest::Error,
    source_chain: &[String],
) -> bool {
    const BODY_INTERRUPTION_MARKERS: [&str; 10] = [
        "unexpected eof",
        "end of file",
        "connection reset",
        "connection closed",
        "connection aborted",
        "broken pipe",
        "incomplete message",
        "error reading a body from connection",
        "chunk size",
        "request or response body error",
    ];

    let reading_body = matches!(stage, "response_body" | "streaming_response_body");
    reading_body
        && (error.is_body() || error.is_decode())
        && source_chain.iter().any(|entry| {
            let lowered = entry.to_ascii_lowercase();
            BODY_INTERRUPTION_MARKERS
                .iter()
                .any(|&marker| lowered.contains(marker))
        })
}
/// Builds a classified error for a non-success HTTP response.
///
/// Classification by status:
/// - 408 Request Timeout -> `Timeout`, retryable
/// - 429 Too Many Requests -> `RateLimited`, retryable
/// - 401 / 403 -> `AuthError`, fail fast
/// - other 5xx -> `ServerError`, retryable
/// - other 4xx -> `ContractError`, fail fast
/// - anything else -> `Unknown`, fail fast
///
/// The status code is recorded both on the error and in its diagnostics. The URL is
/// passed through `sanitize_transport_url`. The message embeds `body` verbatim.
pub(crate) fn classify_status_error_with_trace(
    context: &str,
    stage: &str,
    provider: Option<&str>,
    model_ref: Option<&str>,
    url: Option<&str>,
    status: StatusCode,
    body: String,
    trace: Option<&ProviderHttpTraceRequest>,
) -> anyhow::Error {
    let (kind, disposition) = match status {
        // RFC 9110 §15.5.9: the server gave up waiting for the request and the client may
        // repeat it. This is a transient timeout, not a malformed (contract) request.
        StatusCode::REQUEST_TIMEOUT => {
            (ProviderFailureKind::Timeout, RetryDisposition::Retryable)
        }
        StatusCode::TOO_MANY_REQUESTS => {
            (ProviderFailureKind::RateLimited, RetryDisposition::Retryable)
        }
        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
            (ProviderFailureKind::AuthError, RetryDisposition::FailFast)
        }
        _ if status.is_server_error() => {
            (ProviderFailureKind::ServerError, RetryDisposition::Retryable)
        }
        _ if status.is_client_error() => {
            (ProviderFailureKind::ContractError, RetryDisposition::FailFast)
        }
        _ => (ProviderFailureKind::Unknown, RetryDisposition::FailFast),
    };
    let status_code = status.as_u16();
    provider_transport_error(
        ProviderFailureClassification { kind, disposition },
        Some(status_code),
        Some(ProviderTransportDiagnostics {
            stage: stage.to_string(),
            provider: provider.map(ToString::to_string),
            model_ref: model_ref.map(ToString::to_string),
            url: url.map(sanitize_transport_url),
            status: Some(status_code),
            reqwest: None,
            http_trace: trace.and_then(|trace| trace.diagnostics(Some(status_code))),
            source_chain: Vec::new(),
        }),
        format!("{context} with status {status}: {body}"),
    )
}
/// Builds a fail-fast `InvalidResponse` error with the message `"{context}: {error}"`.
///
/// No status or diagnostics are attached.
pub(crate) fn invalid_response_error(
    context: &str,
    error: impl std::fmt::Display,
) -> anyhow::Error {
    let classification = ProviderFailureClassification {
        kind: ProviderFailureKind::InvalidResponse,
        disposition: RetryDisposition::FailFast,
    };
    let message = format!("{context}: {error}");
    provider_transport_error(classification, None, None, message)
}
/// Builds a retryable `Timeout` error for a timeout detected outside reqwest.
///
/// `reason` becomes the single entry of the diagnostics' source chain. The error
/// message is `context` unchanged.
pub(crate) fn timeout_transport_error_with_trace(
    context: &str,
    stage: &str,
    provider: &str,
    model_ref: Option<&str>,
    url: Option<&str>,
    reason: impl Into<String>,
    trace: Option<&ProviderHttpTraceRequest>,
) -> anyhow::Error {
    let diagnostics = ProviderTransportDiagnostics {
        stage: String::from(stage),
        provider: Some(String::from(provider)),
        model_ref: model_ref.map(String::from),
        url: url.map(sanitize_transport_url),
        status: None,
        reqwest: None,
        http_trace: trace.and_then(|request| request.diagnostics(None)),
        source_chain: vec![reason.into()],
    };
    let classification = ProviderFailureClassification {
        kind: ProviderFailureKind::Timeout,
        disposition: RetryDisposition::Retryable,
    };
    provider_transport_error(classification, None, Some(diagnostics), context.to_owned())
}
/// Removes credentials (userinfo), the query string and the fragment from a URL, so it
/// can be recorded in diagnostics without leaking secrets such as API keys passed as
/// query parameters.
///
/// Some inputs either fail to parse as an absolute URL (e.g. a scheme-less
/// `example.com/v1?key=…`) or parse without an authority. These are sanitized textually
/// instead of being returned verbatim, so a malformed URL cannot carry secrets into logs.
pub(crate) fn sanitize_transport_url(raw: &str) -> String {
    match reqwest::Url::parse(raw) {
        // e.g. `user:secret@host/path` parses as scheme `user` with an opaque path, which
        // the setters below cannot clean up; fall back to the textual sanitizer.
        Ok(url) if url.cannot_be_a_base() => sanitize_unparsed_transport_url(raw),
        Ok(mut url) => {
            // These setters only fail for URLs that cannot hold credentials, in which
            // case there is nothing to strip.
            let _ = url.set_username("");
            let _ = url.set_password(None);
            url.set_query(None);
            url.set_fragment(None);
            url.to_string()
        }
        Err(_) => sanitize_unparsed_transport_url(raw),
    }
}

/// Best-effort textual sanitizer for strings that are not usable absolute URLs.
///
/// It drops everything from the first `?` or `#`. It also drops any `userinfo@` prefix of
/// the authority, which is the segment after an optional `scheme://` and before the first `/`.
fn sanitize_unparsed_transport_url(raw: &str) -> String {
    let end = raw.find(|c: char| c == '?' || c == '#').unwrap_or(raw.len());
    let without_query = &raw[..end];
    let (scheme_prefix, rest) = match without_query.find("://") {
        Some(index) => without_query.split_at(index + "://".len()),
        None => ("", without_query),
    };
    let authority_len = rest.find('/').unwrap_or(rest.len());
    let after_userinfo = match rest[..authority_len].rfind('@') {
        Some(at) => &rest[at + 1..],
        None => rest,
    };
    format!("{scheme_prefix}{after_userinfo}")
}
fn reqwest_transport_diagnostics(
stage: &str,
provider: &str,
model_ref: Option<&str>,
url: Option<&str>,
error: &reqwest::Error,
source_chain: Vec<String>,
trace: Option<&ProviderHttpTraceRequest>,
) -> ProviderTransportDiagnostics {
let status = error.status().map(|status| status.as_u16());
ProviderTransportDiagnostics {
stage: stage.to_string(),
provider: Some(provider.to_string()),
model_ref: model_ref.map(ToString::to_string),
url: url
.or_else(|| error.url().map(reqwest::Url::as_str))
.map(sanitize_transport_url),
status,
reqwest: Some(ReqwestTransportDiagnostics {
is_timeout: error.is_timeout(),
is_connect: error.is_connect(),
is_request: error.is_request(),
is_body: error.is_body(),
is_decode: error.is_decode(),
is_redirect: error.is_redirect(),
status,
}),
http_trace: trace.and_then(|trace| trace.diagnostics(status)),
source_chain,
}
}
fn error_chain_messages(error: &reqwest::Error) -> Vec<String> {
let mut chain = Vec::new();
let mut current = error.source();
while let Some(source) = current {
let message = source.to_string();
if !message.trim().is_empty() {
chain.push(message);
}
current = source.source();
}
chain
}
pub(crate) fn format_provider_failure(
model_ref: &str,
attempts: usize,
error: &anyhow::Error,
) -> String {
let classification = classify_provider_error(error);
let status = error
.downcast_ref::<ProviderTransportError>()
.and_then(|error| error.status)
.map(|status| format!(", status={status}"))
.unwrap_or_default();
match classification.disposition {
RetryDisposition::Retryable => format!(
"{model_ref}: retries_exhausted after {attempts} attempts ({kind}{status}): {error}",
kind = classification.kind.as_str()
),
RetryDisposition::FailFast => format!(
"{model_ref}: fail_fast ({kind}{status}): {error}",
kind = classification.kind.as_str()
),
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for URL sanitization, retryable-failure heuristics, HTTP status
    //! classification, and the advertised retry policy.

    use super::{ProviderFailureKind, RetryDisposition};

    #[test]
    fn transport_url_sanitizer_removes_credentials_query_and_fragment() {
        assert_eq!(
            super::sanitize_transport_url(
                "https://user:secret@example.com/v1/responses?api_key=token#frag"
            ),
            "https://example.com/v1/responses"
        );
    }

    #[test]
    fn transport_url_sanitizer_leaves_clean_urls_unchanged() {
        assert_eq!(
            super::sanitize_transport_url("https://example.com/v1/responses"),
            "https://example.com/v1/responses"
        );
    }

    #[test]
    fn streaming_request_send_connection_source_chain_is_retryable() {
        let source_chain = vec![
            "client error (SendRequest)".to_string(),
            "connection error".to_string(),
            "peer closed connection without sending TLS close_notify".to_string(),
        ];
        assert!(super::is_retryable_request_send_transport_failure(
            "streaming_request_send",
            &source_chain
        ));
    }

    #[test]
    fn streaming_request_send_matching_is_case_insensitive() {
        let source_chain = vec!["Connection RESET by peer".to_string()];
        assert!(super::is_retryable_request_send_transport_failure(
            "streaming_request_send",
            &source_chain
        ));
    }

    #[test]
    fn request_send_connection_source_chain_is_stage_limited() {
        let source_chain = vec!["connection error".to_string()];
        assert!(!super::is_retryable_request_send_transport_failure(
            "response_status",
            &source_chain
        ));
    }

    #[test]
    fn streaming_request_send_non_transport_source_chain_is_not_retryable() {
        let source_chain = vec![
            "builder error".to_string(),
            "invalid header value".to_string(),
        ];
        assert!(!super::is_retryable_request_send_transport_failure(
            "streaming_request_send",
            &source_chain
        ));
    }

    #[test]
    fn streaming_request_send_empty_source_chain_is_not_retryable() {
        assert!(!super::is_retryable_request_send_transport_failure(
            "streaming_request_send",
            &[]
        ));
    }

    #[test]
    fn status_errors_are_classified_by_status_family() {
        let cases = [
            (
                reqwest::StatusCode::TOO_MANY_REQUESTS,
                ProviderFailureKind::RateLimited,
                RetryDisposition::Retryable,
            ),
            (
                reqwest::StatusCode::UNAUTHORIZED,
                ProviderFailureKind::AuthError,
                RetryDisposition::FailFast,
            ),
            (
                reqwest::StatusCode::FORBIDDEN,
                ProviderFailureKind::AuthError,
                RetryDisposition::FailFast,
            ),
            (
                reqwest::StatusCode::SERVICE_UNAVAILABLE,
                ProviderFailureKind::ServerError,
                RetryDisposition::Retryable,
            ),
            (
                reqwest::StatusCode::BAD_REQUEST,
                ProviderFailureKind::ContractError,
                RetryDisposition::FailFast,
            ),
            (
                reqwest::StatusCode::NOT_FOUND,
                ProviderFailureKind::ContractError,
                RetryDisposition::FailFast,
            ),
        ];
        for (status, kind, disposition) in cases {
            let error = super::classify_status_error_with_trace(
                "request failed",
                "response_status",
                Some("provider"),
                None,
                Some("https://example.com/v1/responses"),
                status,
                String::new(),
                None,
            );
            let classification = super::classify_provider_error(&error);
            assert_eq!(classification.kind, kind, "kind for {status}");
            assert_eq!(classification.disposition, disposition, "disposition for {status}");
        }
    }

    #[test]
    fn retry_policy_partitions_every_failure_kind() {
        let policy = super::provider_retry_policy_json();
        let mut listed: Vec<&str> = ["retryable_failure_kinds", "fail_fast_failure_kinds"]
            .iter()
            .flat_map(|key| policy[*key].as_array().expect("policy kind lists are arrays"))
            .map(|kind| kind.as_str().expect("policy kinds are strings"))
            .collect();
        listed.sort_unstable();
        let mut expected = vec![
            "timeout",
            "connection",
            "rate_limited",
            "server_error",
            "auth_error",
            "contract_error",
            "invalid_response",
            "unsupported_transport",
            "unknown",
        ];
        expected.sort_unstable();
        assert_eq!(listed, expected);
    }
}