use crate::stream::AssistantMessageEventStream;
use crate::types::*;
use futures::StreamExt;
use reqwest::header::HeaderMap;
use serde::Serialize;
use std::collections::HashMap;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// Default retry budget. If passed as `max_retries` to `send_request_with_retry`, this allows
/// up to 2 retries after the initial attempt (3 requests total).
/// NOTE(review): call sites are outside this view — confirm where this default is applied.
pub const DEFAULT_MAX_RETRIES: u32 = 2;
/// Default cap, in milliseconds, for a single backoff delay (30 s). Presumably passed as
/// `max_retry_delay_ms`; a value of `0` there disables capping — TODO confirm call sites.
pub const DEFAULT_MAX_RETRY_DELAY_MS: u64 = 30_000;
/// Marker prepended to error messages by `emit_incomplete_stream_error` and recognised by
/// `parse_incomplete_stream_error`.
pub const INCOMPLETE_STREAM_ERROR_PREFIX: &str = "[incomplete_stream]";
/// Base delay for the exponential backoff in `compute_retry_delay` (attempt 0, before jitter).
const RETRY_BASE_DELAY_MS: u64 = 500;
/// Picks the base URL to use for a request.
///
/// Precedence is: the per-call option, then the model's configured URL, then `default`.
/// An explicitly supplied empty string still wins over later candidates.
pub fn resolve_base_url<'a>(
    options_base_url: Option<&'a str>,
    model_base_url: Option<&'a str>,
    default: &'a str,
) -> &'a str {
    match (options_base_url, model_base_url) {
        (Some(explicit), _) => explicit,
        (None, Some(from_model)) => from_model,
        (None, None) => default,
    }
}
/// Serializes `request` to a JSON string, optionally letting `hook` rewrite it first.
///
/// Without a hook, `request` is serialized directly. With a hook, `request` is first
/// converted to a `serde_json::Value` and handed (with a clone of `model`) to the hook.
/// If the hook returns `Some(modified)`, that value is serialized. If it returns `None`,
/// the unmodified `Value` is serialized.
///
/// # Errors
/// Returns any `serde_json` serialization error, boxed.
pub async fn apply_on_payload<T: Serialize>(
    request: &T,
    hook: &Option<OnPayloadFn>,
    model: &Model,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    let Some(hook) = hook else {
        return serde_json::to_string(request).map_err(Into::into);
    };
    let original = serde_json::to_value(request)?;
    // The hook takes ownership of its input, so keep a copy for the "no change" case.
    let serialized = match hook(original.clone(), model.clone()).await {
        Some(modified) => serde_json::to_string(&modified),
        None => serde_json::to_string(&original),
    };
    serialized.map_err(Into::into)
}
/// Validates `base` against `limits.url`.
///
/// Returns `true` when the URL passes. On failure, logs the error, marks `output` as
/// `StopReason::Error` with a `"URL validation error: …"` message, pushes an `Error` event
/// carrying a clone of `output`, ends `stream`, and returns `false`.
pub fn validate_url_or_error(
    base: &str,
    limits: &SecurityConfig,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
) -> bool {
    let Err(e) = limits.url.validate(base) else {
        return true;
    };
    tracing::error!(url = %base, error = %e, "Base URL validation failed");
    output.error_message = Some(format!("URL validation error: {}", e));
    output.stop_reason = StopReason::Error;
    stream.push(AssistantMessageEvent::Error {
        reason: StopReason::Error,
        error: output.clone(),
    });
    stream.end(None);
    false
}
/// Returns a prefix of `body` that is at most `max_len` bytes long, for log previews.
///
/// If `body` fits, it is returned unchanged. Otherwise the cut is moved back to the nearest
/// UTF-8 character boundary at or before `max_len`, so the slice never panics on multi-byte
/// characters. This is the same result as `str::floor_char_boundary`, built only from the
/// long-stable `is_char_boundary`.
pub fn debug_preview(body: &str, max_len: usize) -> &str {
    if max_len >= body.len() {
        return body;
    }
    // At most 3 steps back are ever needed; index 0 is always a boundary.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| body.is_char_boundary(idx))
        .unwrap_or(0);
    &body[..cut]
}
/// Raises a requested token limit to at least 16, leaving `None` untouched.
///
/// NOTE(review): presumably some OpenAI endpoints reject limits below 16 — confirm against
/// the API in use.
pub fn clamp_openai_max_tokens(max_tokens: Option<u32>) -> Option<u32> {
    const MIN_MAX_TOKENS: u32 = 16;
    max_tokens.map(|requested| std::cmp::max(requested, MIN_MAX_TOKENS))
}
pub fn apply_custom_headers(
headers: &mut HeaderMap,
custom: &Option<HashMap<String, String>>,
policy: &HeaderPolicy,
) {
if let Some(ref custom_headers) = custom {
for (key, value) in custom_headers {
if policy.is_protected(key) {
tracing::warn!(header = %key, "Skipping protected header override");
continue;
}
if let Ok(header_name) = reqwest::header::HeaderName::try_from(key.clone()) {
if let Ok(header_value) = reqwest::header::HeaderValue::try_from(value.clone()) {
headers.insert(header_name, header_value);
}
}
}
}
}
/// Turns a non-success HTTP `response` into a terminal error on `stream`.
///
/// Steps:
/// 1. Reads the response body via `crate::types::read_error_body`, passing
///    `limits.http.max_error_body_bytes` as its limit.
/// 2. Logs the status and body at error level.
/// 3. For 4xx statuses, also logs the outgoing `request_body` at warn level.
/// 4. Marks `output` as `StopReason::Error` with an `"HTTP <status>: <body>"` message,
///    truncated to `limits.http.max_error_message_chars`.
/// 5. Pushes an `Error` event and ends the stream.
pub async fn handle_error_response(
    response: reqwest::Response,
    url: &str,
    model: &Model,
    limits: &SecurityConfig,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
    provider_name: &str,
    request_body: &str,
) {
    let status = response.status();
    let body_limit = limits.http.max_error_body_bytes;
    let body = crate::types::read_error_body(response, body_limit).await;

    tracing::error!(
        url = %url,
        model = %model.id,
        status = %status,
        response_body = %body,
        "{} request failed", provider_name
    );
    let is_client_fault = status.is_client_error();
    if is_client_fault {
        tracing::warn!(
            url = %url,
            model = %model.id,
            status = %status,
            request_body = %request_body,
            "{} client error request body dump", provider_name
        );
    }

    let summary = format!("HTTP {}: {}", status, body);
    output.error_message = Some(crate::types::truncate_error_message(
        &summary,
        limits.http.max_error_message_chars,
    ));
    output.stop_reason = StopReason::Error;
    stream.push(AssistantMessageEvent::Error {
        reason: StopReason::Error,
        error: output.clone(),
    });
    stream.end(None);
}
/// Guards against an unbounded SSE line buffer.
///
/// Returns `false` while `buffer_len <= max_bytes`. Once the limit is exceeded, it:
/// - logs the sizes;
/// - marks `output` as `StopReason::Error`;
/// - pushes an `Error` event and ends `stream`;
/// - returns `true` so the caller stops reading.
pub fn check_sse_buffer_overflow(
    buffer_len: usize,
    max_bytes: usize,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
) -> bool {
    if buffer_len <= max_bytes {
        return false;
    }
    tracing::error!(
        buffer_size = buffer_len,
        limit = max_bytes,
        "SSE line buffer exceeded limit, aborting stream"
    );
    output.error_message = Some(String::from("SSE line buffer exceeded maximum size"));
    output.stop_reason = StopReason::Error;
    stream.push(AssistantMessageEvent::Error {
        reason: StopReason::Error,
        error: output.clone(),
    });
    stream.end(None);
    true
}
pub fn emit_aborted(output: &mut AssistantMessage, stream: &AssistantMessageEventStream) {
output.stop_reason = StopReason::Aborted;
output.error_message = Some("Aborted".to_string());
stream.push(AssistantMessageEvent::Error {
reason: StopReason::Aborted,
error: output.clone(),
});
stream.end(None);
}
/// Sends `request`, racing it against `cancel_token` when one is supplied.
///
/// Returns:
/// - `Ok(Some(response))` when the request completes first;
/// - `Ok(None)` when cancellation wins. In that case `emit_aborted` has already marked
///   `output` and ended `stream`;
/// - `Err(e)` for transport errors.
pub async fn send_request_with_cancel(
    request: reqwest::RequestBuilder,
    cancel_token: Option<&CancellationToken>,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
) -> Result<Option<reqwest::Response>, reqwest::Error> {
    let Some(token) = cancel_token else {
        return request.send().await.map(Some);
    };
    tokio::select! {
        _ = token.cancelled() => {
            emit_aborted(output, stream);
            Ok(None)
        }
        response = request.send() => response.map(Some),
    }
}
/// Pulls the next item from `source`, racing it against `cancel_token` when one is supplied.
///
/// Returns `None` both when `source` is exhausted and when cancellation fires. In the
/// cancellation case, `emit_aborted` has already marked `output` and ended `stream`.
pub async fn next_stream_item_with_cancel<S, T, E>(
    source: &mut S,
    cancel_token: Option<&CancellationToken>,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
) -> Option<Result<T, E>>
where
    S: futures::Stream<Item = Result<T, E>> + Unpin,
{
    match cancel_token {
        None => source.next().await,
        Some(token) => tokio::select! {
            _ = token.cancelled() => {
                emit_aborted(output, stream);
                None
            }
            item = source.next() => item,
        },
    }
}
/// Returns `true` for HTTP statuses treated as transient: 408, 429, 500, 502, 503 and 504.
pub fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    const RETRYABLE_CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
    RETRYABLE_CODES.contains(&status.as_u16())
}
/// Returns `true` when a transport error is a timeout or a connection failure.
pub fn is_retryable_error(err: &reqwest::Error) -> bool {
    if err.is_timeout() {
        return true;
    }
    err.is_connect()
}
/// Like [`is_retryable_error`], but also treats errors raised while reading the response
/// body as retryable.
pub fn is_retryable_stream_error(err: &reqwest::Error) -> bool {
    err.is_body() || is_retryable_error(err)
}
/// Reads the `Retry-After` header of `response` as a delay.
///
/// Accepts either delta-seconds (`"120"`) or an HTTP date. A date in the past yields
/// `Duration::ZERO`. Returns `None` when the header is absent, is not valid visible ASCII,
/// or matches neither form.
pub fn parse_retry_after(response: &reqwest::Response) -> Option<Duration> {
    let raw = response.headers().get("retry-after")?.to_str().ok()?;
    match raw.parse::<u64>() {
        Ok(secs) => Some(Duration::from_secs(secs)),
        Err(_) => httpdate::parse_http_date(raw).ok().map(|when| {
            when.duration_since(std::time::SystemTime::now())
                .unwrap_or(Duration::ZERO)
        }),
    }
}
/// Computes a jittered exponential-backoff delay for retry `attempt` (0-based).
///
/// The base delay is `RETRY_BASE_DELAY_MS * 2^min(attempt, 10)`, capped at `max_delay_ms`.
/// A cap of `0` disables capping. A random jitter in `[0, base/4)` is then added so that
/// concurrent clients do not retry in lock-step.
///
/// When `max_delay_ms` is non-zero, the returned delay never exceeds it. If adding the
/// jitter would overshoot the cap, the jitter is subtracted instead, so retries sitting at
/// the cap are still spread out just below it rather than collapsing onto one value.
pub fn compute_retry_delay(attempt: u32, max_delay_ms: u64) -> Duration {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // Clamping the shift keeps `1 << n` well inside u64 and bounds the base at 512 s.
    let exp_delay = RETRY_BASE_DELAY_MS.saturating_mul(1u64 << attempt.min(10));
    let capped = if max_delay_ms == 0 {
        exp_delay
    } else {
        exp_delay.min(max_delay_ms)
    };

    let jitter_range = capped / 4;
    let jitter = if jitter_range > 0 {
        // Each `RandomState::new()` is initialised with random keys, so hashing nothing
        // still yields a fresh pseudo-random value. The previous source
        // (`subsec_nanos() % range`) is unreliable: with a microsecond-resolution clock the
        // nanos are always multiples of 1000, which makes the remainder 0 for the default
        // ranges (125, 250, 500, ... ms) and disables jitter entirely.
        RandomState::new().build_hasher().finish() % jitter_range
    } else {
        0
    };

    // `capped` is at most 512_000, so neither the sum nor the difference can overflow or
    // underflow (jitter < capped / 4).
    let delay_ms = if max_delay_ms != 0 && capped + jitter > max_delay_ms {
        capped - jitter
    } else {
        capped + jitter
    };
    Duration::from_millis(delay_ms)
}
/// Limits `delay` to `max_delay_ms` milliseconds; a limit of `0` means "no limit".
pub fn cap_retry_delay(delay: Duration, max_delay_ms: u64) -> Duration {
    match max_delay_ms {
        0 => delay,
        limit => std::cmp::min(delay, Duration::from_millis(limit)),
    }
}
/// Sleeps for `duration`, waking early if `cancel_token` fires.
///
/// Returns `true` if cancellation happened first, `false` if the full sleep completed.
/// Without a token this is a plain `tokio::time::sleep`.
pub async fn sleep_with_cancel(
    duration: Duration,
    cancel_token: Option<&CancellationToken>,
) -> bool {
    match cancel_token {
        None => {
            tokio::time::sleep(duration).await;
            false
        }
        Some(token) => tokio::select! {
            _ = token.cancelled() => true,
            _ = tokio::time::sleep(duration) => false,
        },
    }
}
/// Finalises `output` as an error and closes `stream` with it.
///
/// Steps:
/// - If `output` has no content, inserts a single empty text block.
/// - Stores `error_message`, truncated to `max_error_message_chars`.
/// - Sets `StopReason::Error`.
/// - Pushes an `Error` event with a clone of `output` and ends the stream.
pub fn emit_terminal_error(
    output: &mut AssistantMessage,
    error_message: impl Into<String>,
    max_error_message_chars: usize,
    stream: &AssistantMessageEventStream,
) {
    // NOTE(review): presumably consumers expect at least one content block, even on
    // failure — confirm.
    if output.content.is_empty() {
        output.content = vec![ContentBlock::Text(TextContent::new(""))];
    }
    let full_message: String = error_message.into();
    let truncated = crate::types::truncate_error_message(&full_message, max_error_message_chars);
    output.stop_reason = StopReason::Error;
    output.error_message = Some(truncated);
    stream.push(AssistantMessageEvent::Error {
        reason: StopReason::Error,
        error: output.clone(),
    });
    stream.end(None);
}
/// Emits a terminal error tagged with `INCOMPLETE_STREAM_ERROR_PREFIX` and `provider`.
///
/// The message has the form `"[incomplete_stream]<provider>: <detail>"`, which
/// `parse_incomplete_stream_error` can split back apart. `detail` is truncated on its own,
/// and the full message is truncated again by `emit_terminal_error`, both to
/// `max_error_message_chars`.
pub fn emit_incomplete_stream_error(
    output: &mut AssistantMessage,
    provider: &str,
    detail: impl Into<String>,
    max_error_message_chars: usize,
    stream: &AssistantMessageEventStream,
) {
    let raw_detail: String = detail.into();
    let clipped_detail = crate::types::truncate_error_message(&raw_detail, max_error_message_chars);
    let tagged_message = format!(
        "{INCOMPLETE_STREAM_ERROR_PREFIX}{provider}: {}",
        clipped_detail
    );
    emit_terminal_error(output, tagged_message, max_error_message_chars, stream);
}
/// Splits a message produced by `emit_incomplete_stream_error` into `(provider, detail)`.
///
/// The message must start with `INCOMPLETE_STREAM_ERROR_PREFIX`. It is split at the first
/// `':'` after the prefix, and both parts are whitespace-trimmed. Returns `None` if the
/// prefix or the separator is missing.
pub fn parse_incomplete_stream_error(error_message: &str) -> Option<(String, String)> {
    error_message
        .strip_prefix(INCOMPLETE_STREAM_ERROR_PREFIX)
        .and_then(|rest| rest.split_once(':'))
        .map(|(provider, detail)| (provider.trim().to_owned(), detail.trim().to_owned()))
}
/// Reports a failure from a background task as a terminal error on `stream`.
///
/// Does nothing if `stream` is already done. Otherwise it:
/// - builds a fresh `AssistantMessage` for `model`, using `fallback_api` when the model
///   has no API set;
/// - gives it a single empty text block;
/// - emits it through `emit_terminal_error` with a 4096-character message limit.
///
/// # Panics
/// Panics if the message builder rejects the fields set here, which would be a bug.
pub fn emit_background_task_error(
    model: &Model,
    fallback_api: Api,
    error_message: impl Into<String>,
    stream: &AssistantMessageEventStream,
) {
    const BACKGROUND_ERROR_MAX_CHARS: usize = 4096;

    if !stream.is_done() {
        let resolved_api = model.api.clone().unwrap_or(fallback_api);
        let mut message = AssistantMessage::builder()
            .api(resolved_api)
            .provider(model.provider.clone())
            .model(model.id.clone())
            .usage(Usage::default())
            .stop_reason(StopReason::Error)
            .build()
            .expect("background task error message should be buildable");
        message.content = vec![ContentBlock::Text(TextContent::new(""))];
        emit_terminal_error(&mut message, error_message, BACKGROUND_ERROR_MAX_CHARS, stream);
    }
}
/// POSTs `body` to `url`, retrying transient failures with backoff.
///
/// Retry rules:
/// - A response with a retryable status (see [`is_retryable_status`]) or a retryable
///   transport error (see [`is_retryable_error`]) is retried up to `max_retries` times.
/// - For retryable statuses, a parseable `Retry-After` header sets the delay, capped by
///   `max_retry_delay_ms`. Otherwise [`compute_retry_delay`] is used.
/// - A `max_retry_delay_ms` of `0` disables capping.
///
/// Returns:
/// - `Ok(Some(response))`: the first non-retryable response, or the last response once
///   retries are exhausted. Callers must still check its status.
/// - `Ok(None)`: `cancel_token` fired, either during a send or during a backoff sleep.
///   The aborted event has already been pushed and `stream` ended.
/// - `Err(e)`: a non-retryable transport error, or the last error once retries are
///   exhausted.
///
/// Each failed retryable response is dropped *before* the backoff sleep. Previously it was
/// held (together with its connection) for the entire delay.
pub async fn send_request_with_retry(
    client: &reqwest::Client,
    url: &str,
    headers: HeaderMap,
    body: String,
    timeout: Duration,
    max_retries: u32,
    max_retry_delay_ms: u64,
    cancel_token: Option<&CancellationToken>,
    output: &mut AssistantMessage,
    stream: &AssistantMessageEventStream,
) -> Result<Option<reqwest::Response>, reqwest::Error> {
    let mut attempt: u32 = 0;
    loop {
        // `RequestBuilder` is consumed by `send`, so it is rebuilt for every attempt.
        let request = client
            .post(url)
            .timeout(timeout)
            .headers(headers.clone())
            .body(body.clone());
        let retries_left = attempt < max_retries;

        let delay = match send_request_with_cancel(request, cancel_token, output, stream).await {
            // Cancelled while sending: `send_request_with_cancel` already emitted the abort.
            Ok(None) => return Ok(None),
            Ok(Some(response)) if retries_left && is_retryable_status(response.status()) => {
                let delay = parse_retry_after(&response)
                    .map(|d| cap_retry_delay(d, max_retry_delay_ms))
                    .unwrap_or_else(|| compute_retry_delay(attempt, max_retry_delay_ms));
                tracing::warn!(
                    url = %url,
                    status = %response.status(),
                    attempt = attempt + 1,
                    max_retries = max_retries,
                    delay_ms = delay.as_millis() as u64,
                    "Retryable HTTP status, backing off before retry"
                );
                // `response` goes out of scope here, before the sleep below.
                delay
            }
            Ok(Some(response)) => return Ok(Some(response)),
            Err(e) if retries_left && is_retryable_error(&e) => {
                let delay = compute_retry_delay(attempt, max_retry_delay_ms);
                tracing::warn!(
                    url = %url,
                    error = %e,
                    attempt = attempt + 1,
                    max_retries = max_retries,
                    delay_ms = delay.as_millis() as u64,
                    "Retryable transport error, backing off before retry"
                );
                delay
            }
            Err(e) => return Err(e),
        };

        if sleep_with_cancel(delay, cancel_token).await {
            emit_aborted(output, stream);
            return Ok(None);
        }
        attempt += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A cap of 0 must mean "uncapped", not "zero delay".
    #[test]
    fn test_compute_retry_delay_zero_cap_disables_capping() {
        let expected_floor = Duration::from_millis(RETRY_BASE_DELAY_MS * 2);
        let observed = compute_retry_delay(1, 0);
        assert!(
            observed >= expected_floor,
            "zero cap should not collapse retry delay to zero"
        );
    }

    /// A cap of 0 leaves server-provided delays untouched.
    #[test]
    fn test_cap_retry_delay_zero_cap_is_unbounded() {
        let five_seconds = Duration::from_secs(5);
        let result = cap_retry_delay(five_seconds, 0);
        assert_eq!(result, five_seconds);
    }
}