use std::collections::BTreeMap;
use std::fmt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use crate::json_payload::JsonPayload;
use crate::providers::ProviderKind;
/// Shorthand `Result` whose error type defaults to this module's [`Error`] enum.
///
/// A different error type can still be supplied explicitly as the second parameter.
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Top-level error type for this crate; the default error of the [`Result`] alias.
///
/// Variants marked `#[error(transparent)]` forward both `Display` and `source()` to the
/// wrapped error, and `#[from]` generates a `From` impl so `?` converts them automatically.
#[derive(Debug, Error)]
pub enum Error {
/// Client configuration is invalid ("客户端配置无效"); the payload describes the problem.
#[error("客户端配置无效: {0}")]
InvalidConfig(String),
/// A request was missing a required field ("请求缺少必填字段"); `field` names it.
#[error("请求缺少必填字段: {field}")]
MissingRequiredField {
/// Name of the missing field.
field: &'static str,
},
/// No API credentials were available ("缺少 API 凭证").
#[error("缺少 API 凭证")]
MissingCredentials,
/// Transparent wrapper around an [`ApiError`].
#[error(transparent)]
Api(#[from] ApiError),
/// Transparent wrapper around a [`ConnectionError`].
#[error(transparent)]
Connection(#[from] ConnectionError),
/// The request timed out ("请求超时").
#[error("请求超时")]
Timeout,
/// Transparent wrapper around a [`StreamError`].
#[error(transparent)]
Stream(#[from] StreamError),
/// Transparent wrapper around a [`WebSocketError`].
#[error(transparent)]
WebSocket(#[from] WebSocketError),
/// Transparent wrapper around a [`SerializationError`].
#[error(transparent)]
Serialization(#[from] SerializationError),
/// Transparent wrapper around a [`LengthFinishReasonError`] (output cut off at the length limit).
#[error(transparent)]
LengthFinishReason(#[from] LengthFinishReasonError),
/// Transparent wrapper around a [`ContentFilterFinishReasonError`] (blocked by a content filter).
#[error(transparent)]
ContentFilterFinishReason(#[from] ContentFilterFinishReasonError),
/// Transparent wrapper around a [`WebhookVerificationError`].
#[error(transparent)]
WebhookVerification(#[from] WebhookVerificationError),
/// Transparent wrapper around a [`ProviderCompatibilityError`].
#[error(transparent)]
ProviderCompatibility(#[from] ProviderCompatibilityError),
/// The request was cancelled ("请求已取消").
#[error("请求已取消")]
Cancelled,
}
/// Coarse classification of an API error, derived from its HTTP status code.
///
/// Serialized in `snake_case` (for example `rate_limit`, `internal_server`).
/// Use [`ApiErrorKind::from_status`] to obtain a kind from a status code.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ApiErrorKind {
    /// HTTP 400.
    BadRequest,
    /// HTTP 401.
    Authentication,
    /// HTTP 403.
    PermissionDenied,
    /// HTTP 404.
    NotFound,
    /// HTTP 409.
    Conflict,
    /// HTTP 422.
    UnprocessableEntity,
    /// HTTP 429.
    RateLimit,
    /// Any status in the 500–599 range.
    InternalServer,
    /// Any status not covered by another variant.
    Unknown,
}

impl ApiErrorKind {
    /// Classifies an HTTP status code.
    ///
    /// Every code in 500–599 collapses to [`ApiErrorKind::InternalServer`]. Codes without a
    /// dedicated variant, including unlisted 4xx codes and anything outside 400–599, map to
    /// [`ApiErrorKind::Unknown`].
    pub fn from_status(status: u16) -> Self {
        // The server-error band is a range; everything else is an exact-code lookup.
        if (500..=599).contains(&status) {
            return Self::InternalServer;
        }

        match status {
            400 => Self::BadRequest,
            401 => Self::Authentication,
            403 => Self::PermissionDenied,
            404 => Self::NotFound,
            409 => Self::Conflict,
            422 => Self::UnprocessableEntity,
            429 => Self::RateLimit,
            _ => Self::Unknown,
        }
    }
}
/// An error response returned by a provider's API.
///
/// Build it with [`ApiError::new`], which fills in `kind` from `status`.
/// It is displayed as `"<message> (status <status>)"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiError {
    /// HTTP status code of the failed response.
    pub status: u16,
    /// Classification of `status`; [`ApiError::new`] sets it via [`ApiErrorKind::from_status`].
    pub kind: ApiErrorKind,
    /// Human-readable error message.
    pub message: String,
    /// Request identifier, if one was available (presumably provider-issued — confirm).
    pub request_id: Option<String>,
    /// Provider that produced the error.
    pub provider: ProviderKind,
    /// Raw error payload, if retained (presumably the unparsed response body — confirm).
    pub raw: Option<JsonPayload>,
}

impl ApiError {
    /// Creates an `ApiError`, classifying `status` with [`ApiErrorKind::from_status`].
    pub fn new(
        status: u16,
        message: impl Into<String>,
        request_id: Option<String>,
        provider: ProviderKind,
        raw: Option<JsonPayload>,
    ) -> Self {
        let kind = ApiErrorKind::from_status(status);
        let message = message.into();
        Self {
            status,
            kind,
            message,
            request_id,
            provider,
            raw,
        }
    }
}

impl fmt::Display for ApiError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            message, status, ..
        } = self;
        write!(f, "{message} (status {status})")
    }
}

// Empty impl: `source()` keeps its default of `None`, so an `ApiError` wraps no cause.
impl std::error::Error for ApiError {}
/// A connection-level failure, described only by a free-form message.
///
/// `Display` prints `message` verbatim.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct ConnectionError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl ConnectionError {
    /// Creates a connection error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
/// A serialization or deserialization failure, described by a free-form message.
///
/// `Display` prints `message` verbatim.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct SerializationError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl SerializationError {
    /// Creates a serialization error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
/// A failure while consuming a streamed response, described by a free-form message.
///
/// `Display` prints `message` verbatim.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct StreamError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl StreamError {
    /// Creates a stream error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

/// The response content cannot be parsed further because the model stopped early at its
/// length limit.
#[derive(Debug, Error, Clone)]
#[error("无法继续解析响应内容: 模型因长度上限提前结束")]
pub struct LengthFinishReasonError;

/// The response content cannot be parsed further because the request was blocked by a
/// content filter.
#[derive(Debug, Error, Clone)]
#[error("无法继续解析响应内容: 请求被内容过滤器拦截")]
pub struct ContentFilterFinishReasonError;
/// Category of a [`WebSocketError`]; serialized in `snake_case`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WebSocketErrorKind {
/// Failure of the underlying transport (presumably socket/I/O level — confirm at call sites).
Transport,
/// Protocol-level problem; also the kind that `WebSocketError::new` assigns.
Protocol,
/// Error attributed to the server; produced by `WebSocketError::server`.
Server,
}
/// An error raised on a WebSocket connection.
///
/// `Display` prints only `message`; `kind` and `event_type` are available as fields.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct WebSocketError {
    /// Which layer the error is attributed to.
    pub kind: WebSocketErrorKind,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional event type; only [`WebSocketError::server`] can set it, all other
    /// constructors leave it `None`.
    pub event_type: Option<String>,
}

impl WebSocketError {
    /// Creates a [`WebSocketErrorKind::Protocol`] error; identical to
    /// [`WebSocketError::protocol`].
    pub fn new(message: impl Into<String>) -> Self {
        Self::protocol(message)
    }

    /// Creates a [`WebSocketErrorKind::Transport`] error with no event type.
    pub fn transport(message: impl Into<String>) -> Self {
        Self::compose(WebSocketErrorKind::Transport, message.into(), None)
    }

    /// Creates a [`WebSocketErrorKind::Protocol`] error with no event type.
    pub fn protocol(message: impl Into<String>) -> Self {
        Self::compose(WebSocketErrorKind::Protocol, message.into(), None)
    }

    /// Creates a [`WebSocketErrorKind::Server`] error, optionally tagged with an event type.
    pub fn server(message: impl Into<String>, event_type: Option<String>) -> Self {
        Self::compose(WebSocketErrorKind::Server, message.into(), event_type)
    }

    /// Single place where the struct literal is assembled for all public constructors.
    fn compose(kind: WebSocketErrorKind, message: String, event_type: Option<String>) -> Self {
        Self {
            kind,
            message,
            event_type,
        }
    }
}
/// A webhook could not be verified; the reason is carried as a free-form message.
///
/// `Display` prints `message` verbatim.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct WebhookVerificationError {
    /// Human-readable description of why verification failed.
    pub message: String,
}

impl WebhookVerificationError {
    /// Creates a verification error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
/// A request or feature is incompatible with the selected provider.
///
/// `Display` prints only `message`; the provider is available via the `provider` field.
#[derive(Debug, Error, Clone)]
#[error("{message}")]
pub struct ProviderCompatibilityError {
    /// Human-readable description of the incompatibility.
    pub message: String,
    /// Provider the incompatibility applies to.
    pub provider: ProviderKind,
}

impl ProviderCompatibilityError {
    /// Creates a compatibility error for `provider` with the given message.
    pub fn new(provider: ProviderKind, message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message, provider }
    }
}
/// Structured error object, presumably parsed from a provider's error response body.
///
/// NOTE(review): the field set (`message`, `type`, `param`, `code`) mirrors the OpenAI-style
/// error object — confirm for each provider. Every named field is optional (missing keys
/// deserialize to `None`), and any keys not matched by a named field are collected into
/// `extra` through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ErrorBody {
/// Human-readable error message, if present.
pub message: Option<String>,
/// Error category; read from and written to the JSON key `type`.
#[serde(rename = "type")]
pub error_type: Option<String>,
/// Request parameter the error refers to, if reported (assumed meaning — confirm).
pub param: Option<String>,
/// Provider error code, kept as an opaque JSON payload (assumed because its JSON type varies — confirm).
pub code: Option<JsonPayload>,
/// All remaining top-level keys not captured by the fields above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}