use async_openai::error::{ApiError, OpenAIError};
/// Error type covering the failure modes of a chat / LLM request.
///
/// The `Display` impl renders a detailed message that includes the raw payload,
/// while `display_message` produces a short, sanitized, length-capped text.
#[derive(Debug)]
pub enum ChatError {
/// Authentication failure. Built for HTTP 401/403 and for API errors whose
/// code or message indicates an invalid key / missing authentication.
ApiAuth(String),
/// Rate limiting. Built for HTTP 429 and for API errors whose code or message
/// indicates a rate limit.
ApiRateLimit {
/// Raw message from the upstream error.
message: String,
/// Optional wait time in seconds shown to the user by `display_message`.
/// The constructors visible in this file always set it to `None`.
retry_after_secs: Option<u64>,
},
/// Malformed request. Built for HTTP 400 and for API errors flagged as invalid.
ApiBadRequest(String),
/// Server-side failure (HTTP 5xx). `from_http_status` strips HTML from `message`.
ApiServerError { status: u16, message: String },
/// A `reqwest` error for which `is_timeout()` returned `true`.
NetworkTimeout(String),
/// A `reqwest` error that is neither a timeout nor carries an HTTP status.
NetworkError(String),
/// Built from `OpenAIError::StreamError`.
StreamInterrupted(String),
/// Built from `OpenAIError::JSONDeserialize`; holds the offending content
/// after it has been passed through `truncate(.., 500)`.
StreamDeserialize(String),
/// Built from `OpenAIError::InvalidArgument`.
RequestBuild(String),
/// The request was aborted by a hook. Not constructed in the visible part of
/// this file — presumably raised by the caller's hook machinery.
HookAborted,
/// Runtime failure; not constructed in the visible part of this file.
RuntimeFailed(String),
/// Agent panic; not constructed in the visible part of this file.
AgentPanic(String),
/// Unexpected `finish_reason` value (the payload is printed as
/// `finish_reason=<payload>`); not constructed in the visible part of this file.
AbnormalFinish(String),
/// Fallback for anything that does not fit the variants above.
Other(String),
}
impl ChatError {
#[allow(dead_code)]
pub fn is_auth_error(&self) -> bool {
matches!(self, Self::ApiAuth(_))
}
pub fn display_message(&self) -> String {
match self {
Self::ApiAuth(_) => "API 认证失败,请检查 API Key".to_string(),
Self::ApiRateLimit {
retry_after_secs, ..
} => {
if let Some(secs) = retry_after_secs {
format!("请求过于频繁,请在 {} 秒后重试", secs)
} else {
"请求过于频繁,请稍后重试".to_string()
}
}
Self::ApiBadRequest(msg) => {
format!("请求参数错误: {}", truncate(&sanitize_html(msg), 150))
}
Self::ApiServerError { status, .. } => {
format!("{} (HTTP {})", http_status_label(*status), status)
}
Self::NetworkTimeout(_) => "网络连接超时,请检查网络后重试".to_string(),
Self::NetworkError(_) => "网络连接失败,请检查网络设置".to_string(),
Self::StreamInterrupted(msg) => {
format!("流式响应中断: {}", truncate(&sanitize_html(msg), 100))
}
Self::StreamDeserialize(msg) => {
format!("响应解析失败: {}", truncate(&sanitize_html(msg), 100))
}
Self::RequestBuild(msg) => {
format!("构建请求失败: {}", truncate(msg, 150))
}
Self::HookAborted => "请求被 hook 中止".to_string(),
Self::RuntimeFailed(msg) => {
format!("运行时错误: {}", truncate(msg, 100))
}
Self::AgentPanic(msg) => {
format!("Agent 异常: {}", truncate(msg, 100))
}
Self::AbnormalFinish(reason) => {
format!("API 返回异常: finish_reason={}", truncate(reason, 80))
}
Self::Other(msg) => truncate(&sanitize_html(msg), 200),
}
}
}
impl std::fmt::Display for ChatError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ApiAuth(msg) => write!(f, "API 认证失败: {}", msg),
Self::ApiRateLimit { message, .. } => write!(f, "API 请求过于频繁: {}", message),
Self::ApiBadRequest(msg) => write!(f, "API 请求参数错误: {}", msg),
Self::ApiServerError { status, message } => {
write!(f, "API 服务端错误 (HTTP {}): {}", status, message)
}
Self::NetworkTimeout(msg) => write!(f, "网络连接超时: {}", msg),
Self::NetworkError(msg) => write!(f, "网络错误: {}", msg),
Self::StreamInterrupted(msg) => write!(f, "流式响应中断: {}", msg),
Self::StreamDeserialize(msg) => write!(f, "流式反序列化失败: {}", msg),
Self::RequestBuild(msg) => write!(f, "构建请求失败: {}", msg),
Self::HookAborted => write!(f, "LLM 请求被 hook 中止"),
Self::RuntimeFailed(msg) => write!(f, "运行时错误: {}", msg),
Self::AgentPanic(msg) => write!(f, "Agent 异常: {}", msg),
Self::AbnormalFinish(reason) => write!(f, "API 返回异常: finish_reason={}", reason),
Self::Other(msg) => write!(f, "{}", msg),
}
}
}
// Empty impl: `ChatError` relies on the trait's default methods, so no
// underlying `source()` is exposed (the variants store messages as `String`s).
impl std::error::Error for ChatError {}
impl From<OpenAIError> for ChatError {
    /// Classifies an `async_openai` error into the matching [`ChatError`] variant.
    ///
    /// Transport errors are split into timeout / HTTP-status / generic network
    /// failures; API errors are delegated to `from_api_error`; anything not
    /// recognised ends up in [`ChatError::Other`] with its `Display` text.
    fn from(e: OpenAIError) -> Self {
        match e {
            OpenAIError::Reqwest(http_err) => {
                if http_err.is_timeout() {
                    return Self::NetworkTimeout(http_err.to_string());
                }
                match http_err.status() {
                    Some(code) => Self::from_http_status(code.as_u16(), http_err.to_string()),
                    None => Self::NetworkError(http_err.to_string()),
                }
            }
            OpenAIError::ApiError(api_err) => Self::from_api_error(api_err),
            OpenAIError::JSONDeserialize(_, content) => {
                // Keep only a bounded prefix of the payload that failed to parse.
                let excerpt = truncate(&content, 500);
                Self::StreamDeserialize(excerpt)
            }
            OpenAIError::StreamError(_) => Self::StreamInterrupted(e.to_string()),
            OpenAIError::InvalidArgument(reason) => Self::RequestBuild(reason),
            _ => Self::Other(e.to_string()),
        }
    }
}
impl From<reqwest::Error> for ChatError {
    /// Classifies a raw `reqwest` error: timeouts first, then errors that carry
    /// an HTTP status (via `from_http_status`), and finally generic network
    /// failures.
    fn from(e: reqwest::Error) -> Self {
        if e.is_timeout() {
            return Self::NetworkTimeout(e.to_string());
        }
        match e.status() {
            Some(code) => Self::from_http_status(code.as_u16(), e.to_string()),
            None => Self::NetworkError(e.to_string()),
        }
    }
}
impl ChatError {
    /// Maps an HTTP status code plus its message onto a [`ChatError`] variant.
    ///
    /// * 401 / 403 -> [`ChatError::ApiAuth`]
    /// * 429 -> [`ChatError::ApiRateLimit`] (no retry hint is available here)
    /// * 400 -> [`ChatError::ApiBadRequest`]
    /// * 5xx -> [`ChatError::ApiServerError`], with HTML stripped from the message
    /// * anything else -> [`ChatError::Other`], prefixed with `HTTP <status>: `
    pub(crate) fn from_http_status(status: u16, message: String) -> Self {
        match status {
            401 | 403 => ChatError::ApiAuth(message),
            429 => ChatError::ApiRateLimit {
                message,
                retry_after_secs: None,
            },
            400 => ChatError::ApiBadRequest(message),
            500..=599 => ChatError::ApiServerError {
                status,
                // Upstream gateways often answer with an HTML error page.
                message: sanitize_html(&message),
            },
            _ => ChatError::Other(format!("HTTP {}: {}", status, message)),
        }
    }

    /// Classifies a structured API error.
    ///
    /// The machine-readable `code` is checked first. When it is absent or not
    /// recognised, the message text is searched (case-insensitively) for
    /// well-known phrases, in the order auth -> rate limit -> bad request.
    /// Unclassified errors become [`ChatError::Other`] with HTML stripped.
    ///
    /// `api_err` is owned, so its message is moved into the result rather than
    /// cloned.
    fn from_api_error(api_err: ApiError) -> Self {
        match api_err.code.as_deref() {
            Some("rate_limit_exceeded") => ChatError::ApiRateLimit {
                message: api_err.message,
                retry_after_secs: None,
            },
            Some("invalid_api_key") | Some("authentication_required") => {
                ChatError::ApiAuth(api_err.message)
            }
            Some("invalid_request_error") => ChatError::ApiBadRequest(api_err.message),
            _ => {
                // Fall back to keyword matching on the human-readable message.
                let msg_lower = api_err.message.to_lowercase();
                if msg_lower.contains("api key")
                    || msg_lower.contains("unauthorized")
                    || msg_lower.contains("authentication")
                {
                    ChatError::ApiAuth(api_err.message)
                } else if msg_lower.contains("rate limit")
                    || msg_lower.contains("too many requests")
                {
                    ChatError::ApiRateLimit {
                        message: api_err.message,
                        retry_after_secs: None,
                    }
                } else if msg_lower.contains("invalid") || msg_lower.contains("bad request") {
                    ChatError::ApiBadRequest(api_err.message)
                } else {
                    ChatError::Other(sanitize_html(&api_err.message))
                }
            }
        }
    }
}
/// Strips HTML-like tags from `input` and normalises whitespace.
///
/// Everything between a `<` and the next `>` (inclusive) is dropped. Each closed
/// tag leaves a separator behind, so text from adjacent elements does not get
/// glued together (`<div>hello</div><p>world</p>` becomes `hello world`).
/// Runs of whitespace are then collapsed to single spaces and the result is
/// trimmed.
///
/// This is a lightweight scrubber for error messages, not an HTML parser: an
/// unmatched `<` swallows the rest of the input.
fn sanitize_html(input: &str) -> String {
    // Each pushed separator replaces a consumed '>', so the output never
    // outgrows the input.
    let mut stripped = String::with_capacity(input.len());
    let mut in_tag = false;
    for ch in input.chars() {
        match ch {
            '<' => in_tag = true,
            '>' => {
                in_tag = false;
                // Word boundary in place of the tag; collapsed below.
                stripped.push(' ');
            }
            _ if !in_tag => stripped.push(ch),
            _ => {}
        }
    }
    stripped.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Shortens `s` so the result is at most `max_len` bytes long.
///
/// Strings that already fit are returned unchanged. Otherwise the string is cut
/// and `...` is appended; the ellipsis counts towards `max_len`, so
/// `truncate("hello world", 8)` yields `"hello..."`. The cut is moved back to
/// the nearest UTF-8 character boundary so multi-byte text is never split.
/// When `max_len` is too small to hold the ellipsis, the string is cut to
/// `max_len` bytes with no ellipsis.
fn truncate(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.len() <= max_len {
        return s.to_string();
    }

    // Reserve room for the ellipsis inside the budget when it fits.
    let (mut end, suffix) = if max_len >= ELLIPSIS.len() {
        (max_len - ELLIPSIS.len(), ELLIPSIS)
    } else {
        (max_len, "")
    };

    // Index 0 is always a boundary, so this cannot underflow.
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    format!("{}{}", &s[..end], suffix)
}
/// Returns the label shown to users for a server-side HTTP status code.
///
/// Codes without a dedicated entry fall back to a generic server-error label.
fn http_status_label(status: u16) -> &'static str {
    const LABELS: [(u16, &str); 5] = [
        (500, "服务端内部错误"),
        (502, "网关错误"),
        (503, "服务暂不可用"),
        (504, "网关超时"),
        (529, "服务过载"),
    ];

    LABELS
        .iter()
        .find(|(code, _)| *code == status)
        .map_or("服务端错误", |(_, label)| *label)
}
// Unit tests for the private text helpers and for `ChatError::display_message`.
#[cfg(test)]
mod tests {
use super::*;
// Tags are removed and whitespace is normalised. The last assertion expects
// text from adjacent elements to come out separated by a single space.
#[test]
fn test_sanitize_html() {
assert_eq!(sanitize_html("<html><body>504</body></html>"), "504");
assert_eq!(
sanitize_html("<html><body>504 Gateway Timeout</body></html>"),
"504 Gateway Timeout"
);
assert_eq!(sanitize_html("plain text"), "plain text");
assert_eq!(sanitize_html("<div>hello</div><p>world</p>"), "hello world");
}
// A string within the limit is returned unchanged; an 11-byte input with a
// limit of 8 is expected to yield an 8-byte result ending in "...".
#[test]
fn test_truncate() {
assert_eq!(truncate("hello", 10), "hello");
assert_eq!(truncate("hello world", 8), "hello...");
}
// Server errors show only the status label and code; the HTML body carried in
// `message` must not appear in the user-facing text.
#[test]
fn test_display_message_504() {
let err = ChatError::ApiServerError {
status: 504,
message: "<html><body>504 Gateway Timeout</body></html>".to_string(),
};
assert_eq!(err.display_message(), "网关超时 (HTTP 504)");
}
// Auth errors use a fixed text and do not echo the raw message.
#[test]
fn test_display_message_auth() {
let err = ChatError::ApiAuth("invalid api key".to_string());
assert_eq!(err.display_message(), "API 认证失败,请检查 API Key");
}
// When a retry hint is present, the number of seconds is included in the text.
#[test]
fn test_display_message_rate_limit() {
let err = ChatError::ApiRateLimit {
message: "too many requests".to_string(),
retry_after_secs: Some(30),
};
assert_eq!(err.display_message(), "请求过于频繁,请在 30 秒后重试");
}
}