1use std::error::Error as StdError;
7use std::fmt;
8use std::io;
9
10#[derive(Debug)]
12pub struct Error {
13 kind: ErrorKind,
14 message: String,
15 cause: Option<Box<dyn StdError + Send + Sync>>,
16 sqlstate_code: Option<String>,
18 detail: Option<String>,
20 hint: Option<String>,
22}
23
/// Broad category of an [`Error`], returned by [`Error::kind`].
///
/// The type is a plain `Copy` tag. It derives `Hash` in addition to `Eq` so callers can
/// use kinds as `HashMap`/`HashSet` keys (e.g. to count failures per category).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// Connection failure; produced by [`Error::connection`].
    Connection,
    /// Authentication failure; produced by [`Error::authentication`].
    Authentication,
    /// Query failure; produced by [`Error::query`] and [`Error::db`]. Only this kind
    /// falls back to parsing a SQLSTATE out of the message in [`Error::sqlstate`].
    Query,
    /// Protocol-level failure; produced by [`Error::protocol`].
    Protocol,
    /// Wrapped `std::io::Error`; produced by [`Error::io`] and `From<io::Error>`.
    Io,
    /// Configuration error (no dedicated constructor; use [`Error::new`]).
    Config,
    /// Timeout; produced by [`Error::timeout`] ("operation timed out").
    Timeout,
    /// Cancellation (no dedicated constructor; use [`Error::new`]).
    Cancelled,
    /// Connection already closed; produced by [`Error::closed`] ("connection closed").
    Closed,
    /// Value conversion error (no dedicated constructor; use [`Error::new`]).
    Conversion,
    /// Unsupported feature; produced by [`Error::feature_not_supported`].
    FeatureNotSupported,
    /// Catch-all; produced by [`Error::other`].
    Other,
}
52
53impl Error {
54 pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
56 Error {
57 kind,
58 message: message.into(),
59 cause: None,
60 sqlstate_code: None,
61 detail: None,
62 hint: None,
63 }
64 }
65
66 pub fn with_cause<E>(kind: ErrorKind, message: impl Into<String>, cause: E) -> Self
68 where
69 E: Into<Box<dyn StdError + Send + Sync>>,
70 {
71 Error {
72 kind,
73 message: message.into(),
74 cause: Some(cause.into()),
75 sqlstate_code: None,
76 detail: None,
77 hint: None,
78 }
79 }
80
81 pub fn new_with_details(
85 kind: ErrorKind,
86 message: impl Into<String>,
87 detail: Option<String>,
88 hint: Option<String>,
89 sqlstate: Option<String>,
90 ) -> Self {
91 Error {
92 kind,
93 message: message.into(),
94 cause: None,
95 sqlstate_code: sqlstate,
96 detail,
97 hint,
98 }
99 }
100
101 #[must_use]
103 pub fn kind(&self) -> ErrorKind {
104 self.kind
105 }
106
107 #[must_use]
109 pub fn message(&self) -> &str {
110 &self.message
111 }
112
113 #[must_use]
115 pub fn detail(&self) -> Option<&str> {
116 self.detail.as_deref()
117 }
118
119 #[must_use]
121 pub fn hint(&self) -> Option<&str> {
122 self.hint.as_deref()
123 }
124
125 pub fn connection(message: impl Into<String>) -> Self {
129 Self::new(ErrorKind::Connection, message)
130 }
131
132 pub fn authentication(message: impl Into<String>) -> Self {
134 Self::new(ErrorKind::Authentication, message)
135 }
136
137 pub fn query(message: impl Into<String>) -> Self {
139 Self::new(ErrorKind::Query, message)
140 }
141
142 pub fn protocol(message: impl Into<String>) -> Self {
144 Self::new(ErrorKind::Protocol, message)
145 }
146
147 #[must_use]
149 pub fn closed() -> Self {
150 Self::new(ErrorKind::Closed, "connection closed")
151 }
152
153 #[must_use]
155 pub fn timeout() -> Self {
156 Self::new(ErrorKind::Timeout, "operation timed out")
157 }
158
159 #[must_use]
161 pub fn io(err: io::Error) -> Self {
162 Self::with_cause(ErrorKind::Io, err.to_string(), err)
163 }
164
165 #[must_use]
167 pub fn db(severity: &str, code: &str, message: &str) -> Self {
168 Error {
169 kind: ErrorKind::Query,
170 message: format!("{severity}: {message} ({code})"),
171 cause: None,
172 sqlstate_code: Some(code.to_string()),
173 detail: None,
174 hint: None,
175 }
176 }
177
178 pub fn feature_not_supported(message: impl Into<String>) -> Self {
183 Self::new(ErrorKind::FeatureNotSupported, message)
184 }
185
186 pub fn other(message: impl Into<String>) -> Self {
188 Self::new(ErrorKind::Other, message)
189 }
190
191 #[must_use]
205 pub fn sqlstate(&self) -> Option<&str> {
206 if let Some(ref code) = self.sqlstate_code {
208 return Some(code);
209 }
210 if self.kind == ErrorKind::Query {
212 extract_sqlstate(&self.message)
213 } else {
214 None
215 }
216 }
217}
218
/// Extracts a SQLSTATE code from the last parenthesized group in `message`.
///
/// This is used by `Error::sqlstate` as a fallback for `Query` errors that have no
/// stored code; messages built by `Error::db` end in `(CODE)`. Whitespace inside the
/// parentheses is trimmed.
///
/// A SQLSTATE is exactly five characters drawn from uppercase ASCII letters and digits.
/// Requiring that alphabet (rather than any alphanumeric) prevents an ordinary trailing
/// parenthetical such as `(users)` from being reported as an error code.
///
/// Returns `None` when there is no `(`, no `)` after the last `(`, or the enclosed
/// text is not a well-formed SQLSTATE.
fn extract_sqlstate(message: &str) -> Option<&str> {
    let open = message.rfind('(')?;
    // `(` is a single byte, so `open + 1` is always a char boundary.
    let after_open = &message[open + 1..];
    let close = after_open.find(')')?;
    let code = after_open[..close].trim();

    let is_sqlstate_char = |c: char| c.is_ascii_uppercase() || c.is_ascii_digit();
    if code.len() == 5 && code.chars().all(is_sqlstate_char) {
        Some(code)
    } else {
        None
    }
}
237
238impl fmt::Display for Error {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 write!(f, "{}", self.message)?;
241 if let Some(ref detail) = self.detail {
242 if !self.message.contains(detail) {
243 write!(f, ": {detail}")?;
244 }
245 }
246 if let Some(ref cause) = self.cause {
247 write!(f, ": {cause}")?;
248 }
249 Ok(())
250 }
251}
252
253impl StdError for Error {
254 fn source(&self) -> Option<&(dyn StdError + 'static)> {
255 self.cause.as_ref().map(|e| &**e as &dyn std::error::Error)
256 }
257}
258
259impl From<io::Error> for Error {
260 fn from(err: io::Error) -> Self {
261 Error::io(err)
262 }
263}
264
265pub type Result<T> = std::result::Result<T, Error>;
267
#[cfg(test)]
mod tests {
    use super::*;

    /// `Error::db` stores the code, so `sqlstate()` must return it verbatim.
    #[test]
    fn test_sqlstate_extraction() {
        let cases = [
            ("42P04", "database \"test\" already exists"),
            ("42710", "duplicate object"),
            ("42P06", "schema \"public\" already exists"),
            ("42P07", "table \"users\" already exists"),
        ];
        for (code, text) in cases {
            let err = Error::db("ERROR", code, text);
            assert_eq!(err.sqlstate(), Some(code), "db error built with {code}");
        }
    }

    /// Non-`Query` errors without a stored code report no SQLSTATE.
    #[test]
    fn test_sqlstate_non_query_error() {
        for err in [Error::connection("connection failed"), Error::timeout()] {
            assert_eq!(err.sqlstate(), None);
        }
    }

    /// Parsing of a trailing `(CODE)` group, including malformed inputs.
    #[test]
    fn test_extract_sqlstate_edge_cases() {
        let accepted = [
            "ERROR: message (42P04)",
            "ERROR: message ( 42P04 )",
            "ERROR: (extra info) message (42P04)",
        ];
        for msg in accepted {
            assert_eq!(extract_sqlstate(msg), Some("42P04"), "{msg}");
        }

        let rejected = [
            "ERROR: message (42P)",
            "ERROR: message (42P044)",
            "ERROR: message (42-04)",
            "ERROR: message",
            "ERROR: message ()",
        ];
        for msg in rejected {
            assert_eq!(extract_sqlstate(msg), None, "{msg}");
        }
    }
}