1use std::fmt::{Display, Formatter};
2
3use serde_json::{Map, Value};
4
5use crate::ai::constants::AI_TYPE;
6
/// Category of an [`AiError`].
///
/// Every variant has a stable kebab-case identifier, available through
/// [`AiErrorCode::as_code`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AiErrorCode {
    Error,
    RequestError,
    ResponseError,
    FetchError,
    SessionClosed,
    InvalidContent,
    ApiNotEnabled,
    InvalidSchema,
    NoApiKey,
    NoAppId,
    NoModel,
    NoProjectId,
    ParseFailed,
    Unsupported,
    InvalidArgument,
    Internal,
}

impl AiErrorCode {
    /// Returns the kebab-case identifier of this code, e.g. `"invalid-argument"`.
    ///
    /// The returned string is unqualified; [`AiError::code_str`] prefixes it
    /// with the AI service type to form the full code.
    pub fn as_code(&self) -> &'static str {
        // The enum is `Copy`, so matching on the dereferenced value is free.
        match *self {
            Self::Error => "error",
            Self::RequestError => "request-error",
            Self::ResponseError => "response-error",
            Self::FetchError => "fetch-error",
            Self::SessionClosed => "session-closed",
            Self::InvalidContent => "invalid-content",
            Self::ApiNotEnabled => "api-not-enabled",
            Self::InvalidSchema => "invalid-schema",
            Self::NoApiKey => "no-api-key",
            Self::NoAppId => "no-app-id",
            Self::NoModel => "no-model",
            Self::NoProjectId => "no-project-id",
            Self::ParseFailed => "parse-failed",
            Self::Unsupported => "unsupported",
            Self::InvalidArgument => "invalid-argument",
            Self::Internal => "internal",
        }
    }
}
54
55#[derive(Clone, Debug, PartialEq)]
59pub struct ErrorDetails {
60 pub type_url: Option<String>,
61 pub reason: Option<String>,
62 pub domain: Option<String>,
63 pub metadata: Option<Map<String, Value>>,
64}
65
66impl Default for ErrorDetails {
67 fn default() -> Self {
68 Self {
69 type_url: None,
70 reason: None,
71 domain: None,
72 metadata: None,
73 }
74 }
75}
76
77#[derive(Clone, Debug, Default, PartialEq)]
81pub struct CustomErrorData {
82 pub status: Option<u16>,
83 pub status_text: Option<String>,
84 pub response: Option<Value>,
85 pub error_details: Vec<ErrorDetails>,
86}
87
88impl CustomErrorData {
89 pub fn with_status(mut self, status: u16) -> Self {
90 self.status = Some(status);
91 self
92 }
93
94 pub fn with_status_text<S: Into<String>>(mut self, status_text: S) -> Self {
95 self.status_text = Some(status_text.into());
96 self
97 }
98
99 pub fn with_response(mut self, response: Value) -> Self {
100 self.response = Some(response);
101 self
102 }
103
104 pub fn with_error_details(mut self, details: Vec<ErrorDetails>) -> Self {
105 self.error_details = details;
106 self
107 }
108}
109
110#[derive(Clone, Debug)]
111pub struct AiError {
112 pub code: AiErrorCode,
113 message: String,
114 pub custom_error_data: Option<CustomErrorData>,
115}
116
117impl AiError {
118 pub fn new(
120 code: AiErrorCode,
121 message: impl Into<String>,
122 custom_error_data: Option<CustomErrorData>,
123 ) -> Self {
124 Self {
125 code,
126 message: message.into(),
127 custom_error_data,
128 }
129 }
130
131 pub fn code(&self) -> AiErrorCode {
133 self.code
134 }
135
136 pub fn code_str(&self) -> String {
138 format!("{}/{}", AI_TYPE, self.code.as_code())
139 }
140
141 pub fn message(&self) -> &str {
143 &self.message
144 }
145}
146
147impl Display for AiError {
148 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149 write!(f, "{}: {} ({})", AI_TYPE, self.message, self.code_str())
150 }
151}
152
153impl std::error::Error for AiError {}
154
155pub type AiResult<T> = Result<T, AiError>;
156
157pub fn invalid_argument(message: impl Into<String>) -> AiError {
158 AiError::new(AiErrorCode::InvalidArgument, message, None)
159}
160
161pub fn internal_error(message: impl Into<String>) -> AiError {
162 AiError::new(AiErrorCode::Internal, message, None)
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// `invalid_argument` yields the expected code, message, qualified code and display text.
    #[test]
    fn formats_full_code() {
        let error = invalid_argument("Bad input");

        assert_eq!(error.code(), AiErrorCode::InvalidArgument);
        assert_eq!(error.code_str(), "AI/invalid-argument");
        assert_eq!(error.message(), "Bad input");
        assert_eq!(error.to_string(), "AI: Bad input (AI/invalid-argument)");
    }

    /// Builder-made custom data is stored unchanged on the error.
    #[test]
    fn supports_custom_data_builders() {
        let detail = ErrorDetails {
            type_url: Some(String::from("type.googleapis.com/google.rpc.ErrorInfo")),
            reason: Some(String::from("RATE_LIMIT")),
            domain: None,
            metadata: Some(Map::new()),
        };
        // `detail` is not used afterwards, so it can be moved rather than cloned.
        let expected = CustomErrorData::default()
            .with_status(429)
            .with_status_text("Too Many Requests")
            .with_error_details(vec![detail]);

        let error = AiError::new(
            AiErrorCode::FetchError,
            "quota exceeded",
            Some(expected.clone()),
        );

        assert_eq!(error.custom_error_data, Some(expected));
        assert!(error.code_str().contains("fetch-error"));
    }
}