use caliban_provider::Error as ProviderError;
/// Errors produced by the Google (Gemini) provider adapter.
///
/// Each variant is translated into the provider-agnostic `caliban_provider::Error` by the
/// `From<GoogleError> for ProviderError` impl, which decides how it is classified
/// (auth, rate limit, context overflow, upstream fault, adapter bug, ...).
#[derive(thiserror::Error, Debug)]
pub enum GoogleError {
    /// The underlying `reqwest` request failed. Converted automatically via `#[from]`.
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
    /// An HTTP response carrying an error status. The body is inspected during
    /// conversion (e.g. a 400 body is checked for context-window overflow).
    #[error("response status {status}: {body}")]
    BadStatus {
        /// HTTP status code of the response.
        status: u16,
        /// Response body text.
        body: String,
    },
    /// A JSON payload could not be deserialized. Converted automatically via `#[from]`.
    #[error("deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// A streamed response could not be parsed; the string describes the failure.
    /// NOTE(review): presumably the streaming response framing — confirm with the caller.
    #[error("stream parse error: {0}")]
    StreamParse(String),
    /// An error message reported in-band by the upstream service rather than via the
    /// HTTP status. It is classified the same way as a 400 response body.
    #[error("upstream error: {0}")]
    UpstreamError(String),
    /// A required configuration field was not provided; holds the field's name.
    #[error("missing config field: {0}")]
    MissingConfig(&'static str),
    /// A boxed transport-level error.
    /// NOTE(review): the inner error is not marked `#[source]`, so thiserror's
    /// `Error::source()` returns `None` for this variant; the inner error is only
    /// visible through `Display`.
    #[error("transport error: {0}")]
    Transport(Box<dyn std::error::Error + Send + Sync>),
    /// The request was rejected as invalid; mapped to `ProviderError::InvalidRequest`.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
}
impl From<GoogleError> for ProviderError {
    /// Translates a Gemini adapter error into the provider-agnostic `ProviderError`.
    ///
    /// Routing:
    /// - `Http`: delegated to `caliban_provider::classify_reqwest_error`. An interrupted
    ///   stream keeps only the rendered source chain; the network and adapter classes
    ///   wrap the whole `GoogleError`.
    /// - `BadStatus`: 401/403 → `Auth`, 404 → `ModelUnavailable`, 429 → `RateLimit`
    ///   with no retry hint, 400 → in-band message classification, >= 500 →
    ///   `ServerError`, any other status → `adapter`.
    /// - `UpstreamError`: the same in-band message classification as a 400 body.
    /// - `InvalidRequest`: passed through as `ProviderError::InvalidRequest`.
    /// - `Deserialize`, `StreamParse`, `MissingConfig`, `Transport`: wrapped as `adapter`.
    fn from(e: GoogleError) -> Self {
        use caliban_provider::TransportErrorClass;

        // Shared routing for free-form error text. A context-window overflow takes
        // precedence over an upstream-fault marker; anything unrecognized is a plain
        // invalid request carrying the original text.
        let classify_message = |text: &str| {
            classify_context_length_exceeded(text)
                .or_else(|| classify_upstream_server_fault(text))
                .unwrap_or_else(|| ProviderError::InvalidRequest(text.to_owned()))
        };

        match e {
            GoogleError::Http(ref err) => {
                let class = caliban_provider::classify_reqwest_error("google", err);
                match class {
                    TransportErrorClass::StreamInterrupted => {
                        let chain = caliban_provider::render_source_chain(err);
                        ProviderError::stream_interrupted(chain)
                    }
                    TransportErrorClass::Network => ProviderError::network(e),
                    TransportErrorClass::Adapter => ProviderError::adapter(e),
                }
            }
            GoogleError::BadStatus { status, ref body } => match status {
                401 | 403 => ProviderError::Auth(body.clone()),
                404 => ProviderError::ModelUnavailable(body.clone()),
                429 => ProviderError::RateLimit { retry_after: None },
                400 => classify_message(body.as_str()),
                code if code >= 500 => ProviderError::ServerError {
                    status: code,
                    body: body.clone(),
                },
                _ => ProviderError::adapter(e),
            },
            GoogleError::UpstreamError(ref msg) => classify_message(msg.as_str()),
            // The message is moved out of the error instead of cloned; the result is identical.
            GoogleError::InvalidRequest(msg) => ProviderError::InvalidRequest(msg),
            GoogleError::Deserialize(_)
            | GoogleError::StreamParse(_)
            | GoogleError::MissingConfig(_)
            | GoogleError::Transport(_) => ProviderError::adapter(e),
        }
    }
}
/// Recognizes a context-window overflow in an error body or in-band upstream message.
///
/// Returns `Some(ProviderError::ContextTooLong { .. })` when `body` contains any of the
/// known overflow phrasings (case-sensitive substring match). Token counts come from
/// `parse_token_counts` and are `0` when they cannot be extracted. Returns `None` when
/// no phrasing matches.
fn classify_context_length_exceeded(body: &str) -> Option<ProviderError> {
    const OVERFLOW_MARKERS: [&str; 7] = [
        "exceeds the maximum number of tokens",
        "input token count",
        "context_length_exceeded",
        "Input tokens exceed",
        "Please reduce the length",
        "context window",
        "maximum context length",
    ];

    if OVERFLOW_MARKERS.iter().any(|&marker| body.contains(marker)) {
        let (max_tokens, requested_tokens) = parse_token_counts(body);
        Some(ProviderError::ContextTooLong {
            max_tokens,
            requested_tokens,
        })
    } else {
        None
    }
}
/// Recognizes an upstream server-side fault (internal error, overload, crash, OOM,
/// fatal signal) in an error body or in-band upstream message.
///
/// Matching is a case-sensitive substring search. Both capitalizations are listed
/// where messages are known to vary. On a match, the whole `body` is returned inside
/// `ProviderError::UpstreamServerFault`; otherwise `None`.
fn classify_upstream_server_fault(body: &str) -> Option<ProviderError> {
    const FAULT_MARKERS: &[&str] = &[
        "Internal error",
        "INTERNAL",
        "overloaded",
        "UNAVAILABLE",
        "crashed",
        "Exit code:",
        "out of memory",
        "Out of memory",
        "OOMKilled",
        "segmentation fault",
        "Segmentation fault",
        "killed by signal",
        "Killed by signal",
    ];

    FAULT_MARKERS
        .iter()
        .any(|&marker| body.contains(marker))
        .then(|| ProviderError::UpstreamServerFault(body.to_string()))
}
/// Extracts `(max_tokens, requested_tokens)` from a context-overflow error message.
///
/// Recognized phrasings, tried in order:
/// - max: `"... allowed (M)"`, `"... limit of M"`, `"... maximum context length is M"`
/// - requested: `"... token count (N)"`, `"... resulted in N"`
///
/// The `"maximum context length is "` marker keeps the parser consistent with
/// `classify_context_length_exceeded`, which accepts the "maximum context length" phrasing.
/// Without that marker, such messages reported `max_tokens = 0`.
///
/// Either value is `0` when no recognized marker is followed by a number.
fn parse_token_counts(body: &str) -> (u32, u32) {
    const MAX_MARKERS: [&str; 3] = ["allowed (", "limit of ", "maximum context length is "];
    const REQUESTED_MARKERS: [&str; 2] = ["token count (", "resulted in "];

    let first_number_after = |markers: &[&str]| {
        markers
            .iter()
            .find_map(|marker| extract_u32_after(body, marker))
            .unwrap_or(0)
    };

    (
        first_number_after(&MAX_MARKERS),
        first_number_after(&REQUESTED_MARKERS),
    )
}

/// Returns the number immediately following an occurrence of `marker` in `body`.
///
/// Occurrences are scanned left to right. The first one directly followed by a run of
/// ASCII digits that fits in a `u32` wins. Earlier occurrences with no number after them
/// are skipped instead of aborting the search. Returns `None` when no occurrence
/// qualifies.
fn extract_u32_after(body: &str, marker: &str) -> Option<u32> {
    body.match_indices(marker).find_map(|(idx, _)| {
        let after = &body[idx + marker.len()..];
        // `find` yields a char-boundary byte offset, so the slice below cannot panic.
        let end = after
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(after.len());
        after[..end].parse().ok()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a synthetic response with an arbitrary `status` and `body`.
    fn from_status(status: u16, body: &str) -> ProviderError {
        ProviderError::from(GoogleError::BadStatus {
            status,
            body: body.to_string(),
        })
    }

    /// Converts a synthetic HTTP 400 response carrying `body`.
    fn from_400(body: &str) -> ProviderError {
        from_status(400, body)
    }

    /// Converts an in-band upstream error message.
    fn from_upstream(msg: &str) -> ProviderError {
        ProviderError::from(GoogleError::UpstreamError(msg.to_string()))
    }

    #[test]
    fn gemini_400_context_overflow_routes_to_context_too_long() {
        let body = r#"{"error":{"code":400,"message":"The input token count (1290020) exceeds the maximum number of tokens allowed (1048575).","status":"INVALID_ARGUMENT"}}"#;
        match from_400(body) {
            ProviderError::ContextTooLong {
                max_tokens,
                requested_tokens,
            } => {
                assert_eq!(max_tokens, 1_048_575);
                assert_eq!(requested_tokens, 1_290_020);
            }
            other => panic!("expected ContextTooLong, got {other:?}"),
        }
    }

    #[test]
    fn gemini_400_invalid_argument_stays_invalid_request() {
        let body = r#"{"error":{"code":400,"message":"Invalid value at 'generation_config.temperature' (2.5 out of range)","status":"INVALID_ARGUMENT"}}"#;
        match from_400(body) {
            ProviderError::InvalidRequest(s) => assert!(s.contains("temperature")),
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
    }

    // A 400 body that carries an upstream-fault marker goes through the same fault
    // classifier as in-band upstream messages.
    #[test]
    fn gemini_400_internal_fault_routes_to_server_fault() {
        assert!(matches!(
            from_400("Internal error encountered."),
            ProviderError::UpstreamServerFault(_)
        ));
    }

    #[test]
    fn other_status_arms_unchanged() {
        assert!(matches!(from_status(401, "nope"), ProviderError::Auth(_)));
        assert!(matches!(from_status(403, "forbidden"), ProviderError::Auth(_)));
        assert!(matches!(
            from_status(404, "models/unknown is not found"),
            ProviderError::ModelUnavailable(_)
        ));
        assert!(matches!(
            from_status(429, ""),
            ProviderError::RateLimit { .. }
        ));
        assert!(matches!(
            from_status(503, "overloaded"),
            ProviderError::ServerError { status: 503, .. }
        ));
    }

    // 5xx responses must carry the response body through unchanged, not just the status.
    #[test]
    fn server_error_preserves_status_and_body() {
        match from_status(502, "bad gateway") {
            ProviderError::ServerError { status, body } => {
                assert_eq!(status, 502);
                assert_eq!(body, "bad gateway");
            }
            other => panic!("expected ServerError, got {other:?}"),
        }
    }

    #[test]
    fn upstream_in_band_context_overflow_routes_to_context_too_long() {
        let msg =
            "The input token count (5200) exceeds the maximum number of tokens allowed (4096).";
        match from_upstream(msg) {
            ProviderError::ContextTooLong {
                max_tokens,
                requested_tokens,
            } => {
                assert_eq!(max_tokens, 4096);
                assert_eq!(requested_tokens, 5200);
            }
            other => panic!("expected ContextTooLong, got {other:?}"),
        }
    }

    #[test]
    fn upstream_in_band_internal_fault_routes_to_server_fault() {
        assert!(matches!(
            from_upstream("Internal error encountered."),
            ProviderError::UpstreamServerFault(_)
        ));
    }

    #[test]
    fn upstream_in_band_model_crash_routes_to_server_fault() {
        assert!(matches!(
            from_upstream("model crashed mid-generation"),
            ProviderError::UpstreamServerFault(_)
        ));
    }

    #[test]
    fn upstream_in_band_overloaded_routes_to_server_fault() {
        assert!(matches!(
            from_upstream("The model is overloaded. Please try again later."),
            ProviderError::UpstreamServerFault(_)
        ));
    }

    #[test]
    fn upstream_in_band_plain_message_stays_invalid_request() {
        match from_upstream("function call schema rejected: unknown field 'foo'") {
            ProviderError::InvalidRequest(s) => assert!(s.contains("schema rejected")),
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
    }

    #[test]
    fn context_classifier_wins_over_fault_for_overflow_body() {
        let msg = "context_length_exceeded: Input tokens exceed the limit";
        assert!(matches!(
            from_upstream(msg),
            ProviderError::ContextTooLong { .. }
        ));
    }
}