1use std::fmt;
2use std::io;
3
/// Error type for the PostgreSQL client code in this module.
///
/// Most variants carry a free-form message. `Io` wraps an [`io::Error`] and is
/// the only variant that reports a `source()`; `Query` carries the details of
/// an error associated with a SQL statement.
///
/// Conversions: `io::Error` converts into `Io`, while `String` and `&str`
/// convert into `Protocol`, so applying `?` to a string error classifies it as
/// a protocol error.
#[derive(Debug)]
pub enum PgsqlError {
    /// Displayed as `连接错误: <msg>` ("connection error").
    Connection(String),
    /// Wrapped I/O error; displayed as `IO错误: <err>` and returned by `source()`.
    Io(io::Error),
    /// Displayed as `超时: <msg>` ("timeout").
    Timeout(String),
    /// Displayed as `协议错误: <msg>` ("protocol error").
    /// Also the target of the `From<String>` / `From<&str>` conversions.
    Protocol(String),
    /// Displayed as `认证失败: <msg>` ("authentication failed").
    Auth(String),
    /// Error associated with a SQL statement; all fields are printed by `Display`.
    Query {
        /// Error code, e.g. `"23505"` — presumably the server's SQLSTATE; confirm with the producer.
        code: String,
        /// Error message text.
        message: String,
        /// Detail text accompanying the message.
        detail: String,
        /// SQL text associated with the error (presumably the statement that failed).
        sql: String,
        /// Position reported for the error.
        ///
        /// NOTE(review): `Display` labels this value `line`, but PostgreSQL's
        /// ErrorResponse position field is a 1-based character offset into the
        /// query string — confirm which one the producer stores. As a `u16` it
        /// cannot represent values above 65535.
        position: u16,
    },
    /// Displayed as `连接池错误: <msg>` ("connection pool error").
    Pool(String),
    /// Displayed as `Config error: <msg>`.
    Config(String),
}
21
22impl fmt::Display for PgsqlError {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 match self {
25 PgsqlError::Connection(msg) => write!(f, "连接错误: {}", msg),
26 PgsqlError::Io(err) => write!(f, "IO错误: {}", err),
27 PgsqlError::Timeout(msg) => write!(f, "超时: {}", msg),
28 PgsqlError::Protocol(msg) => write!(f, "协议错误: {}", msg),
29 PgsqlError::Auth(msg) => write!(f, "认证失败: {}", msg),
30 PgsqlError::Query {
31 code,
32 message,
33 detail,
34 sql,
35 position,
36 } => {
37 write!(
38 f,
39 "Code: {} ErrorMsg[line:{}]: {} detail: {} SQL: {}",
40 code, position, message, detail, sql
41 )
42 }
43 PgsqlError::Pool(msg) => write!(f, "连接池错误: {}", msg),
44 PgsqlError::Config(msg) => write!(f, "Config error: {}", msg),
45 }
46 }
47}
48
49impl std::error::Error for PgsqlError {
50 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
51 match self {
52 PgsqlError::Io(err) => Some(err),
53 _ => None,
54 }
55 }
56}
57
58impl From<io::Error> for PgsqlError {
59 fn from(err: io::Error) -> Self {
60 PgsqlError::Io(err)
61 }
62}
63
64impl From<String> for PgsqlError {
65 fn from(s: String) -> Self {
66 PgsqlError::Protocol(s)
67 }
68}
69
70impl From<&str> for PgsqlError {
71 fn from(s: &str) -> Self {
72 PgsqlError::Protocol(s.to_string())
73 }
74}
75
76impl From<PgsqlError> for String {
77 fn from(err: PgsqlError) -> Self {
78 err.to_string()
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::PgsqlError;
    use std::error::Error;
    use std::io;

    /// Builds a representative `Query` error shared by several tests.
    fn sample_query_error() -> PgsqlError {
        PgsqlError::Query {
            code: "23505".to_string(),
            message: "duplicate key".to_string(),
            detail: "key exists".to_string(),
            sql: "INSERT INTO t VALUES (1)".to_string(),
            position: 12,
        }
    }

    #[test]
    fn display_formats_all_variants() {
        let connection = PgsqlError::Connection("network down".to_string());
        assert_eq!(connection.to_string(), "连接错误: network down");

        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "pipe closed"));
        assert_eq!(io_err.to_string(), "IO错误: pipe closed");

        let timeout = PgsqlError::Timeout("request timeout".to_string());
        assert_eq!(timeout.to_string(), "超时: request timeout");

        let protocol = PgsqlError::Protocol("invalid packet".to_string());
        assert_eq!(protocol.to_string(), "协议错误: invalid packet");

        let auth = PgsqlError::Auth("wrong password".to_string());
        assert_eq!(auth.to_string(), "认证失败: wrong password");

        let query = sample_query_error();
        assert_eq!(
            query.to_string(),
            "Code: 23505 ErrorMsg[line:12]: duplicate key detail: key exists SQL: INSERT INTO t VALUES (1)"
        );

        let pool = PgsqlError::Pool("pool exhausted".to_string());
        assert_eq!(pool.to_string(), "连接池错误: pool exhausted");

        let config_err = PgsqlError::Config("bad url".to_string());
        assert_eq!(config_err.to_string(), "Config error: bad url");
    }

    #[test]
    fn source_returns_some_for_io_variant() {
        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::TimedOut, "socket timeout"));

        let source = io_err.source().expect("Io variant should expose source");
        assert_eq!(source.to_string(), "socket timeout");
    }

    #[test]
    fn source_returns_none_for_non_io_variants() {
        // Every variant except `Io` must report no underlying source.
        let errs = [
            PgsqlError::Connection("c".to_string()),
            PgsqlError::Timeout("t".to_string()),
            PgsqlError::Protocol("p".to_string()),
            PgsqlError::Auth("a".to_string()),
            sample_query_error(),
            PgsqlError::Pool("pool".to_string()),
            PgsqlError::Config("cfg".to_string()),
        ];

        for err in &errs {
            assert!(err.source().is_none(), "unexpected source for {err:?}");
        }
    }

    #[test]
    fn from_io_error_creates_io_variant() {
        let err = io::Error::new(io::ErrorKind::NotFound, "missing file");
        let pg_err: PgsqlError = err.into();

        match pg_err {
            PgsqlError::Io(inner) => {
                assert_eq!(inner.kind(), io::ErrorKind::NotFound);
                assert_eq!(inner.to_string(), "missing file");
            }
            other => panic!("expected Io variant, got {other:?}"),
        }
    }

    #[test]
    fn from_string_creates_protocol_variant() {
        let pg_err: PgsqlError = String::from("bad response").into();

        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "bad response"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_str_creates_protocol_variant() {
        let pg_err: PgsqlError = "decode error".into();

        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "decode error"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_pgsql_error_to_string_uses_display_output() {
        let err = PgsqlError::Auth("login failed".to_string());
        let value: String = err.into();

        assert_eq!(value, "认证失败: login failed");
    }

    #[test]
    fn debug_derive_formats_variant_details() {
        let err = sample_query_error();
        let debug = format!("{:?}", err);

        assert!(debug.contains("Query"));
        assert!(debug.contains("23505"));
        assert!(debug.contains("duplicate key"));
        assert!(debug.contains("position: 12"));
    }
}