use std::time::Duration;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use genai::Client;
use genai::chat::ReasoningEffort as GenaiReasoningEffort;
use genai::chat::{ChatOptions, ChatStreamEvent};
use reqwest::StatusCode;
use awaken_contract::contract::content::ContentBlock;
use awaken_contract::contract::executor::{
InferenceExecutionError, InferenceRequest, InferenceStream, LlmExecutor, LlmStreamEvent,
};
use awaken_contract::contract::inference::{
ReasoningEffort as ContractReasoningEffort, StopReason, StreamResult,
};
use super::convert::{build_chat_request, from_genai_tool_call, map_stop_reason, map_usage};
use super::streaming::{StreamCollector, StreamOutput};
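/// Default deadline for a full request, also reused as the per-event idle
/// timeout while streaming.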
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);
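/// Maps the contract-level reasoning effort onto the equivalent `genai` variant.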
fn map_reasoning_effort(effort: &ContractReasoningEffort) -> GenaiReasoningEffort {
match effort {
ContractReasoningEffort::None => GenaiReasoningEffort::None,
ContractReasoningEffort::Low => GenaiReasoningEffort::Low,
ContractReasoningEffort::Medium => GenaiReasoningEffort::Medium,
ContractReasoningEffort::High => GenaiReasoningEffort::High,
ContractReasoningEffort::Max => GenaiReasoningEffort::Max,
ContractReasoningEffort::Budget(n) => GenaiReasoningEffort::Budget(*n),
}
}
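/// Converts a collector [`StreamOutput`] into a public [`LlmStreamEvent`];
/// `StreamOutput::None` carries nothing and maps to `None`.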
fn stream_output_to_llm_event(output: StreamOutput) -> Option<LlmStreamEvent> {
match output {
StreamOutput::TextDelta(delta) => Some(LlmStreamEvent::TextDelta(delta)),
StreamOutput::ReasoningDelta(delta) => Some(LlmStreamEvent::ReasoningDelta(delta)),
StreamOutput::ToolCallStart { id, name } => {
Some(LlmStreamEvent::ToolCallStart { id, name })
}
StreamOutput::ToolCallDelta { id, args_delta } => {
Some(LlmStreamEvent::ToolCallDelta { id, args_delta })
}
StreamOutput::None => None,
}
}
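/// Awaits the next provider event, mapping an idle period longer than
/// `timeout_dur` to [`InferenceExecutionError::Timeout`]. `Ok(None)` means
/// the underlying stream closed normally.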
async fn next_chat_stream_event<S>(
stream: &mut S,
timeout_dur: Duration,
) -> Result<Option<Result<ChatStreamEvent, genai::Error>>, InferenceExecutionError>
where
S: Stream<Item = Result<ChatStreamEvent, genai::Error>> + Unpin,
{
tokio::time::timeout(timeout_dur, stream.next())
.await
.map_err(|_| {
InferenceExecutionError::Timeout(format!(
"stream idle timeout after {}s",
timeout_dur.as_secs()
))
})
}
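/// [`LlmExecutor`] implementation backed by the multi-provider `genai` [`Client`].
///
/// A minimal construction sketch (illustrative, not prescriptive; it assumes
/// provider credentials come from the environment variables the `genai`
/// client resolves, e.g. `OPENAI_API_KEY`):
///
/// ```ignore
/// use std::time::Duration;
/// use genai::chat::ChatOptions;
///
/// let executor = GenaiExecutor::new()
///     .with_options(ChatOptions::default().with_temperature(0.2))
///     .with_timeout(Duration::from_secs(60));
/// ```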
pub struct GenaiExecutor {
client: Client,
default_options: Option<ChatOptions>,
default_timeout: Duration,
}
impl GenaiExecutor {
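/// Creates an executor with a default `genai` client, no base options, and
/// [`DEFAULT_TIMEOUT`].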
pub fn new() -> Self {
Self {
client: Client::default(),
default_options: None,
default_timeout: DEFAULT_TIMEOUT,
}
}
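/// Creates an executor around a pre-configured `genai` [`Client`].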
pub fn with_client(client: Client) -> Self {
Self {
client,
default_options: None,
default_timeout: DEFAULT_TIMEOUT,
}
}
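/// Sets base [`ChatOptions`] that every request starts from; per-request
/// overrides are layered on top.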
#[must_use]
pub fn with_options(mut self, options: ChatOptions) -> Self {
self.default_options = Some(options);
self
}
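/// Overrides both the request deadline and the stream idle timeout.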
#[must_use]
pub fn with_timeout(mut self, duration: Duration) -> Self {
self.default_timeout = duration;
self
}
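/// Builds the effective [`ChatOptions`]: base options (if any), then the
/// capture flags, then per-request overrides for temperature, max tokens,
/// top-p, and reasoning effort.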
fn build_options(&self, request: &InferenceRequest) -> ChatOptions {
let mut opts = self
.default_options
.clone()
.unwrap_or_default()
.with_capture_usage(true)
.with_capture_content(true)
.with_capture_tool_calls(true);
if let Some(ref ovr) = request.overrides {
if let Some(temp) = ovr.temperature {
opts = opts.with_temperature(temp);
}
if let Some(max) = ovr.max_tokens {
opts = opts.with_max_tokens(max);
}
if let Some(top_p) = ovr.top_p {
opts = opts.with_top_p(top_p);
}
if let Some(ref effort) = ovr.reasoning_effort {
opts = opts.with_reasoning_effort(map_reasoning_effort(effort));
opts = opts.with_capture_reasoning_content(true);
}
}
opts
}
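/// Classifies a [`genai::Error`] into an [`InferenceExecutionError`],
/// preferring structured HTTP status data and falling back to substring
/// matching on the rendered error message.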
fn map_error(e: genai::Error) -> InferenceExecutionError {
tracing::warn!(error = ?e, "LLM inference error");
let parts = Self::extract_structured_parts(&e);
let msg = format!("{e:#}");
if let Some((status, body, retry_after)) = parts {
return Self::classify_status(status, &msg, body.as_deref(), retry_after);
}
let lower = msg.to_lowercase();
if lower.contains("content_filter")
|| lower.contains("content policy")
|| lower.contains("content_policy_violation")
|| lower.contains("blocked by safety")
{
InferenceExecutionError::ContentFiltered(msg)
} else if lower.contains("overloaded") {
InferenceExecutionError::overloaded(msg)
} else if lower.contains("rate")
|| lower.contains("429")
|| lower.contains("too many requests")
{
InferenceExecutionError::rate_limited(msg)
} else if lower.contains("timeout") || lower.contains("timed out") {
InferenceExecutionError::Timeout(msg)
} else if lower.contains("503") || lower.contains("502") || lower.contains("500") {
InferenceExecutionError::Provider(msg)
} else {
tracing::warn!(error_msg = %msg, "unclassified LLM error — consider adding a pattern");
InferenceExecutionError::Provider(msg)
}
}
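/// Maps an HTTP status (with optional body and `Retry-After` hint) onto an
/// error variant; 400 responses are further inspected for context overflow.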
fn classify_status(
status: StatusCode,
msg: &str,
body: Option<&str>,
retry_after: Option<Duration>,
) -> InferenceExecutionError {
match status.as_u16() {
429 => InferenceExecutionError::RateLimited {
message: msg.to_string(),
retry_after,
},
529 | 503 => InferenceExecutionError::Overloaded {
message: msg.to_string(),
retry_after,
},
408 | 504 => InferenceExecutionError::Timeout(msg.to_string()),
500 | 502 => InferenceExecutionError::Provider(msg.to_string()),
400 => {
if Self::looks_like_context_overflow(body, msg) {
InferenceExecutionError::ContextOverflow(msg.to_string())
} else {
InferenceExecutionError::InvalidRequest(msg.to_string())
}
}
401 | 403 => InferenceExecutionError::Unauthorized(msg.to_string()),
404 => InferenceExecutionError::ModelNotFound(msg.to_string()),
413 => InferenceExecutionError::ContextOverflow(msg.to_string()),
422 => InferenceExecutionError::InvalidRequest(msg.to_string()),
_ => InferenceExecutionError::Provider(msg.to_string()),
}
}
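/// Heuristic for 400 responses: providers report context overflow with
/// ad-hoc wording, so match known phrases in the response body (falling
/// back to the rendered message when no body is available).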
fn looks_like_context_overflow(body: Option<&str>, msg: &str) -> bool {
const NEEDLES: &[&str] = &[
"prompt is too long",
"context_length_exceeded",
"context length",
"input is too long",
"maximum context length",
"reduce the length",
"too many tokens",
"request too large",
];
let haystack_lower = match body {
Some(b) if !b.is_empty() => b.to_lowercase(),
_ => msg.to_lowercase(),
};
NEEDLES.iter().any(|needle| haystack_lower.contains(needle))
}
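/// Extracts status, body, and `Retry-After` from the error variants that
/// carry a structured HTTP response; returns `None` for everything else.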
fn extract_structured_parts(
e: &genai::Error,
) -> Option<(StatusCode, Option<String>, Option<Duration>)> {
match e {
genai::Error::HttpError { status, .. } => Some((*status, None, None)),
genai::Error::WebAdapterCall { webc_error, .. }
| genai::Error::WebModelCall { webc_error, .. } => match webc_error {
genai::webc::Error::ResponseFailedStatus {
status,
body,
headers,
} => {
let retry = parse_retry_after(headers);
Some((*status, Some(body.clone()), retry))
}
_ => None,
},
_ => None,
}
}
}
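/// Parses a `Retry-After` header in delta-seconds form; the HTTP-date form
/// is not handled and yields `None`.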
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
let value = headers.get(reqwest::header::RETRY_AFTER)?;
let text = value.to_str().ok()?.trim();
let seconds: u64 = text.parse().ok()?;
Some(Duration::from_secs(seconds))
}
impl Default for GenaiExecutor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl LlmExecutor for GenaiExecutor {
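/// Runs a non-streaming chat request, collecting text, tool calls, usage,
/// and stop reason into a single [`StreamResult`].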
async fn execute(
&self,
request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
let model = request.upstream_model.clone();
let tools = request.tools.clone();
let chat_req = build_chat_request(
&request.system,
&request.messages,
&tools,
request.enable_prompt_cache,
);
let opts = self.build_options(&request);
let timeout_dur = self.default_timeout;
let response = tokio::time::timeout(
timeout_dur,
self.client.exec_chat(&model, chat_req, Some(&opts)),
)
.await
.map_err(|_| {
InferenceExecutionError::Timeout(format!(
"inference timeout after {}s",
timeout_dur.as_secs()
))
})?
.map_err(Self::map_error)?;
let text = response.content.first_text().unwrap_or("").to_string();
let tool_calls: Vec<_> = response
.content
.tool_calls()
.into_iter()
.map(from_genai_tool_call)
.collect();
let usage = Some(map_usage(&response.usage));
let stop_reason = response.stop_reason.as_ref().and_then(map_stop_reason);
let content = if text.is_empty() {
vec![]
} else {
vec![ContentBlock::text(text)]
};
Ok(StreamResult {
content,
tool_calls,
usage,
stop_reason,
has_incomplete_tool_calls: false,
})
}
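/// Runs a streaming chat request, adapting provider events into
/// [`LlmStreamEvent`]s. On stream end, captured usage (if any) is emitted
/// before the final `Stop` event.
///
/// A consumption sketch (illustrative; `executor` and `request` are assumed
/// to be in scope):
///
/// ```ignore
/// let mut stream = executor.execute_stream(request).await?;
/// while let Some(event) = stream.next().await {
///     match event? {
///         LlmStreamEvent::TextDelta(t) => print!("{t}"),
///         LlmStreamEvent::Stop(_) => break,
///         _ => {}
///     }
/// }
/// ```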
fn execute_stream(
&self,
request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
Box::pin(async move {
let model = request.upstream_model.clone();
let tools = request.tools.clone();
let chat_req = build_chat_request(
&request.system,
&request.messages,
&tools,
request.enable_prompt_cache,
);
// build_options already enables content capture.
let opts = self.build_options(&request);
let timeout_dur = self.default_timeout;
let stream_response = tokio::time::timeout(
timeout_dur,
self.client.exec_chat_stream(&model, chat_req, Some(&opts)),
)
.await
.map_err(|_| {
InferenceExecutionError::Timeout(format!(
"inference timeout after {}s",
timeout_dur.as_secs()
))
})?
.map_err(Self::map_error)?;
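// Adapt the provider stream into LlmStreamEvents. The unfold state pairs
// the raw stream with a StreamCollector that accumulates partial events.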
let event_stream = futures::stream::unfold(
(stream_response.stream, StreamCollector::new()),
move |(mut stream, mut collector)| async move {
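// Drain any output left pending from a previously processed event
// before polling the provider again.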
if let Some(output) = collector.take_pending_output() {
let event = stream_output_to_llm_event(output)
.expect("pending outputs are never empty");
return Some((Ok(event), (stream, collector)));
}
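// End was already seen (usage went out on the previous poll): emit
// the final Stop event and reset the collector.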
if collector.end_seen() {
let result = collector.finish();
let stop = result.stop_reason.unwrap_or(StopReason::EndTurn);
return Some((
Ok(LlmStreamEvent::Stop(stop)),
(stream, StreamCollector::new()),
));
}
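// Poll raw provider events until one maps to a public LlmStreamEvent.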
loop {
match next_chat_stream_event(&mut stream, timeout_dur).await {
Ok(Some(Ok(event))) => {
let is_end = matches!(event, ChatStreamEvent::End(_));
let output = collector.process(event);
if let Some(event) = stream_output_to_llm_event(output) {
return Some((Ok(event), (stream, collector)));
}
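// On End: surface captured usage first; the end_seen branch above
// then yields Stop on the next poll. Without usage, stop immediately.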
if is_end {
if let Some(usage) = collector.take_usage() {
return Some((
Ok(LlmStreamEvent::Usage(usage)),
(stream, collector),
));
}
let result = collector.finish();
let stop = result.stop_reason.unwrap_or(StopReason::EndTurn);
return Some((
Ok(LlmStreamEvent::Stop(stop)),
(stream, StreamCollector::new()),
));
}
continue;
}
Ok(Some(Err(e))) => {
return Some((Err(Self::map_error(e)), (stream, collector)));
}
Ok(None) => return None,
Err(e) => return Some((Err(e), (stream, collector))),
}
}
},
);
Ok(Box::pin(event_stream) as InferenceStream)
})
}
fn name(&self) -> &str {
"genai"
}
}
#[cfg(test)]
mod tests {
use super::*;
use awaken_contract::contract::executor::InferenceRequest;
use awaken_contract::contract::inference::InferenceOverride;
use awaken_contract::contract::message::Message;
fn make_request(overrides: Option<InferenceOverride>) -> InferenceRequest {
InferenceRequest {
upstream_model: "test-model".into(),
messages: vec![Message::user("hello")],
tools: vec![],
system: vec![],
overrides,
enable_prompt_cache: false,
}
}
#[test]
fn new_creates_executor() {
let exec = GenaiExecutor::new();
assert!(exec.default_options.is_none());
}
#[test]
fn default_creates_executor() {
let exec = GenaiExecutor::default();
assert!(exec.default_options.is_none());
}
#[test]
fn name_returns_genai() {
let exec = GenaiExecutor::new();
assert_eq!(exec.name(), "genai");
}
#[test]
fn build_options_defaults() {
let exec = GenaiExecutor::new();
let req = make_request(None);
let opts = exec.build_options(&req);
assert_eq!(opts.capture_usage, Some(true));
assert_eq!(opts.capture_content, Some(true));
assert_eq!(opts.capture_tool_calls, Some(true));
assert_eq!(opts.temperature, None);
assert_eq!(opts.max_tokens, None);
assert_eq!(opts.top_p, None);
}
#[test]
fn build_options_with_temperature() {
let exec = GenaiExecutor::new();
let req = make_request(Some(InferenceOverride {
temperature: Some(0.5),
..Default::default()
}));
let opts = exec.build_options(&req);
assert_eq!(opts.temperature, Some(0.5));
assert_eq!(opts.max_tokens, None);
assert_eq!(opts.top_p, None);
}
#[test]
fn build_options_with_max_tokens() {
let exec = GenaiExecutor::new();
let req = make_request(Some(InferenceOverride {
max_tokens: Some(1024),
..Default::default()
}));
let opts = exec.build_options(&req);
assert_eq!(opts.max_tokens, Some(1024));
assert_eq!(opts.temperature, None);
assert_eq!(opts.top_p, None);
}
#[test]
fn build_options_with_top_p() {
let exec = GenaiExecutor::new();
let req = make_request(Some(InferenceOverride {
top_p: Some(0.9),
..Default::default()
}));
let opts = exec.build_options(&req);
assert_eq!(opts.top_p, Some(0.9));
assert_eq!(opts.temperature, None);
assert_eq!(opts.max_tokens, None);
}
#[test]
fn build_options_with_all_overrides() {
let exec = GenaiExecutor::new();
let req = make_request(Some(InferenceOverride {
temperature: Some(0.7),
max_tokens: Some(2048),
top_p: Some(0.95),
..Default::default()
}));
let opts = exec.build_options(&req);
assert_eq!(opts.temperature, Some(0.7));
assert_eq!(opts.max_tokens, Some(2048));
assert_eq!(opts.top_p, Some(0.95));
assert_eq!(opts.capture_usage, Some(true));
assert_eq!(opts.capture_content, Some(true));
assert_eq!(opts.capture_tool_calls, Some(true));
}
#[test]
fn build_options_with_default_options() {
let base = ChatOptions::default()
.with_temperature(0.3)
.with_max_tokens(512);
let exec = GenaiExecutor::new().with_options(base);
let req = make_request(Some(InferenceOverride {
temperature: Some(0.9),
..Default::default()
}));
let opts = exec.build_options(&req);
assert_eq!(opts.temperature, Some(0.9));
assert_eq!(opts.max_tokens, Some(512));
assert_eq!(opts.capture_usage, Some(true));
}
#[test]
fn build_options_with_reasoning_effort() {
let exec = GenaiExecutor::new();
let req = make_request(Some(InferenceOverride {
reasoning_effort: Some(ContractReasoningEffort::High),
..Default::default()
}));
let opts = exec.build_options(&req);
assert!(
opts.reasoning_effort.is_some(),
"reasoning_effort should be set"
);
assert_eq!(opts.capture_reasoning_content, Some(true));
}
#[test]
fn build_options_without_reasoning_effort() {
let exec = GenaiExecutor::new();
let req = make_request(None);
let opts = exec.build_options(&req);
assert!(opts.reasoning_effort.is_none());
assert!(opts.capture_reasoning_content.is_none());
}
#[test]
fn map_reasoning_effort_all_variants() {
use super::map_reasoning_effort;
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::None),
GenaiReasoningEffort::None
));
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::Low),
GenaiReasoningEffort::Low
));
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::Medium),
GenaiReasoningEffort::Medium
));
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::High),
GenaiReasoningEffort::High
));
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::Max),
GenaiReasoningEffort::Max
));
assert!(matches!(
map_reasoning_effort(&ContractReasoningEffort::Budget(4096)),
GenaiReasoningEffort::Budget(4096)
));
}
#[test]
fn map_error_rate_limited_429() {
let err = genai::Error::Internal("server returned 429".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_rate_word() {
let err = genai::Error::Internal("rate limit exceeded".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_timeout() {
let err = genai::Error::Internal("connection timeout".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Timeout(_)),
"expected Timeout, got {mapped:?}"
);
}
#[test]
fn map_error_timed_out() {
let err = genai::Error::Internal("request timed out".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Timeout(_)),
"expected Timeout, got {mapped:?}"
);
}
#[test]
fn map_error_generic() {
let err = genai::Error::Internal("something went wrong".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Provider(_)),
"expected Provider, got {mapped:?}"
);
}
#[test]
fn map_error_too_many_requests() {
let err = genai::Error::Internal("Too Many Requests".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_overloaded() {
let err = genai::Error::Internal("server overloaded".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Overloaded { .. }),
"expected Overloaded, got {mapped:?}"
);
}
#[test]
fn map_error_503_string() {
let err = genai::Error::Internal("503 Service Unavailable".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Provider(_)),
"expected Provider, got {mapped:?}"
);
}
#[test]
fn map_error_http_429_structured() {
let err = genai::Error::HttpError {
status: StatusCode::TOO_MANY_REQUESTS,
canonical_reason: "Too Many Requests".into(),
body: "rate limited".into(),
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_http_500_structured() {
let err = genai::Error::HttpError {
status: StatusCode::INTERNAL_SERVER_ERROR,
canonical_reason: "Internal Server Error".into(),
body: "oops".into(),
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Provider(_)),
"expected Provider, got {mapped:?}"
);
}
#[test]
fn map_error_http_504_structured() {
let err = genai::Error::HttpError {
status: StatusCode::GATEWAY_TIMEOUT,
canonical_reason: "Gateway Timeout".into(),
body: "timeout".into(),
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Timeout(_)),
"expected Timeout, got {mapped:?}"
);
}
#[test]
fn map_error_preserves_full_chain() {
let err = genai::Error::Internal("rate limit exceeded".into());
let mapped = GenaiExecutor::map_error(err);
let msg = match mapped {
InferenceExecutionError::RateLimited { message, .. } => message,
other => panic!("expected RateLimited, got {other:?}"),
};
assert!(msg.contains("rate limit exceeded"), "msg was: {msg}");
}
#[test]
fn with_timeout_builder() {
let exec = GenaiExecutor::new().with_timeout(Duration::from_secs(30));
assert_eq!(exec.default_timeout, Duration::from_secs(30));
}
#[tokio::test(start_paused = true)]
async fn timeout_fires_for_slow_future() {
let timeout_dur = Duration::from_secs(120);
let slow = async {
tokio::time::sleep(Duration::from_secs(200)).await;
Ok::<&str, String>("should not reach")
};
let result = tokio::time::timeout(timeout_dur, slow).await;
assert!(result.is_err(), "should have timed out");
}
#[tokio::test(start_paused = true)]
async fn timeout_maps_to_inference_timeout_error() {
let timeout_dur = Duration::from_millis(50);
let slow = async {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok::<(), ()>(())
};
let result = tokio::time::timeout(timeout_dur, slow).await;
assert!(result.is_err());
let mapped = result.map_err(|_| {
InferenceExecutionError::Timeout(format!(
"inference timeout after {}s",
timeout_dur.as_secs()
))
});
assert!(
matches!(mapped, Err(InferenceExecutionError::Timeout(ref msg)) if msg.contains("timeout"))
);
}
#[tokio::test(start_paused = true)]
async fn timeout_does_not_fire_for_fast_future() {
let timeout_dur = Duration::from_secs(120);
let fast = async {
tokio::time::sleep(Duration::from_millis(10)).await;
Ok::<&str, String>("done")
};
let result = tokio::time::timeout(timeout_dur, fast).await;
assert!(result.is_ok(), "fast future should not time out");
assert_eq!(result.unwrap().unwrap(), "done");
}
#[tokio::test(start_paused = true)]
async fn stream_next_timeout_maps_to_inference_timeout_error() {
let mut stream = futures::stream::pending::<Result<ChatStreamEvent, genai::Error>>();
let result = next_chat_stream_event(&mut stream, Duration::from_millis(50)).await;
assert!(
matches!(result, Err(InferenceExecutionError::Timeout(ref msg)) if msg.contains("stream idle timeout")),
"expected stream idle timeout, got {result:?}"
);
}
#[tokio::test(start_paused = true)]
async fn stream_next_timeout_does_not_fire_for_closed_stream() {
let mut stream = futures::stream::empty::<Result<ChatStreamEvent, genai::Error>>();
let result = next_chat_stream_event(&mut stream, Duration::from_secs(120)).await;
assert!(matches!(result, Ok(None)));
}
#[test]
fn map_error_web_adapter_call_429() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::TOO_MANY_REQUESTS,
body: "rate limited".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_web_model_call_429() {
use genai::ModelIden;
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebModelCall {
model_iden: ModelIden::new(AdapterKind::OpenAI, "gpt-4o"),
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::TOO_MANY_REQUESTS,
body: "rate limited".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::RateLimited { .. }),
"expected RateLimited, got {mapped:?}"
);
}
#[test]
fn map_error_web_adapter_call_503_is_overloaded() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::SERVICE_UNAVAILABLE,
body: "overloaded".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Overloaded { .. }),
"expected Overloaded, got {mapped:?}"
);
}
#[test]
fn map_error_web_adapter_call_529_is_overloaded() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::from_u16(529).unwrap(),
body: r#"{"type":"error","error":{"type":"overloaded_error"}}"#.into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Overloaded { .. }),
"expected Overloaded, got {mapped:?}"
);
}
#[test]
fn map_error_http_400_prompt_too_long_is_context_overflow() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::BAD_REQUEST,
body: r#"{"error":{"message":"prompt is too long: 210000 tokens"}}"#.into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContextOverflow(_)),
"expected ContextOverflow, got {mapped:?}"
);
}
#[test]
fn map_error_http_400_context_length_exceeded_is_context_overflow() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebModelCall {
model_iden: genai::ModelIden::new(AdapterKind::OpenAI, "gpt-4o"),
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::BAD_REQUEST,
body: r#"{"error":{"code":"context_length_exceeded"}}"#.into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContextOverflow(_)),
"expected ContextOverflow, got {mapped:?}"
);
}
#[test]
fn map_error_http_400_generic_is_invalid_request() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::BAD_REQUEST,
body: r#"{"error":"messages must be a non-empty array"}"#.into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::InvalidRequest(_)),
"expected InvalidRequest, got {mapped:?}"
);
}
#[test]
fn map_error_http_401_is_unauthorized() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::UNAUTHORIZED,
body: "bad api key".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Unauthorized(_)),
"expected Unauthorized, got {mapped:?}"
);
}
#[test]
fn map_error_http_404_is_model_not_found() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::NOT_FOUND,
body: "no such model".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ModelNotFound(_)),
"expected ModelNotFound, got {mapped:?}"
);
}
#[test]
fn map_error_http_413_is_context_overflow() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::PAYLOAD_TOO_LARGE,
body: "too big".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContextOverflow(_)),
"expected ContextOverflow, got {mapped:?}"
);
}
#[test]
fn map_error_http_422_is_invalid_request() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::UNPROCESSABLE_ENTITY,
body: "schema violation".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::InvalidRequest(_)),
"expected InvalidRequest, got {mapped:?}"
);
}
#[test]
fn map_error_retry_after_seconds_header_is_parsed() {
use genai::adapter::AdapterKind;
use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
let mut headers = HeaderMap::new();
headers.insert(RETRY_AFTER, HeaderValue::from_static("42"));
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::TOO_MANY_REQUESTS,
body: "slow down".into(),
headers: Box::new(headers),
},
};
let mapped = GenaiExecutor::map_error(err);
match mapped {
InferenceExecutionError::RateLimited { retry_after, .. } => {
assert_eq!(retry_after, Some(Duration::from_secs(42)));
}
other => panic!("expected RateLimited with retry_after, got {other:?}"),
}
}
#[test]
fn map_error_retry_after_absent_yields_none() {
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::TOO_MANY_REQUESTS,
body: "no header".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(matches!(
mapped,
InferenceExecutionError::RateLimited {
retry_after: None,
..
}
));
}
#[test]
fn map_error_content_filter_string_maps_to_content_filtered() {
let err =
genai::Error::Internal("response blocked by content_filter policy violation".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContentFiltered(_)),
"expected ContentFiltered, got {mapped:?}"
);
}
#[test]
fn map_error_content_policy_string_maps_to_content_filtered() {
let err = genai::Error::Internal("content policy triggered".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContentFiltered(_)),
"expected ContentFiltered, got {mapped:?}"
);
}
#[test]
fn map_error_safety_string_maps_to_content_filtered() {
let err = genai::Error::Internal("blocked by safety filter".into());
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::ContentFiltered(_)),
"expected ContentFiltered, got {mapped:?}"
);
}
#[test]
fn content_filtered_is_not_retryable_and_does_not_count_toward_breaker() {
let err = InferenceExecutionError::ContentFiltered("policy".into());
assert!(
!err.is_retryable(),
"ContentFiltered must be permanent (no retry)"
);
assert!(
!err.counts_toward_circuit_breaker(),
"ContentFiltered must not increment the breaker"
);
}
#[test]
fn map_error_retry_after_non_numeric_yields_none() {
use genai::adapter::AdapterKind;
use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
let mut headers = HeaderMap::new();
headers.insert(
RETRY_AFTER,
HeaderValue::from_static("Fri, 31 Dec 1999 23:59:59 GMT"),
);
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::Anthropic,
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::from_u16(529).unwrap(),
body: "overloaded".into(),
headers: Box::new(headers),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(matches!(
mapped,
InferenceExecutionError::Overloaded {
retry_after: None,
..
}
));
}
#[test]
fn map_error_web_model_call_504() {
use genai::ModelIden;
use genai::adapter::AdapterKind;
use reqwest::header::HeaderMap;
let err = genai::Error::WebModelCall {
model_iden: ModelIden::new(AdapterKind::OpenAI, "gpt-4o"),
webc_error: genai::webc::Error::ResponseFailedStatus {
status: StatusCode::GATEWAY_TIMEOUT,
body: "gateway timeout".into(),
headers: Box::new(HeaderMap::new()),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Timeout(_)),
"expected Timeout, got {mapped:?}"
);
}
#[test]
fn map_error_web_adapter_call_non_status_error_falls_through() {
use genai::adapter::AdapterKind;
let err = genai::Error::WebAdapterCall {
adapter_kind: AdapterKind::OpenAI,
webc_error: genai::webc::Error::ResponseFailedNotJson {
content_type: "text/html".into(),
body: "not json".into(),
},
};
let mapped = GenaiExecutor::map_error(err);
assert!(
matches!(mapped, InferenceExecutionError::Provider(_)),
"expected Provider, got {mapped:?}"
);
}
}