//! `br_pgsql/error.rs`: the error type used by the `br_pgsql` crate.

1use std::fmt;
2use std::io;
3
/// Error type for the `br_pgsql` PostgreSQL client.
///
/// Most variants carry a free-form message string. [`PgsqlError::Io`] wraps the
/// underlying [`std::io::Error`], and [`PgsqlError::Query`] carries the
/// structured fields describing a failed query.
#[derive(Debug)]
pub enum PgsqlError {
    /// Connection-level error (rendered with the prefix "连接错误", i.e. "connection error").
    Connection(String),
    /// An I/O error from the underlying transport.
    Io(io::Error),
    /// An operation timed out (rendered with the prefix "超时", i.e. "timeout").
    Timeout(String),
    /// Protocol error (rendered with the prefix "协议错误", i.e. "protocol error").
    ///
    /// This is also the variant produced when converting a bare `String` or `&str`.
    Protocol(String),
    /// Authentication failure (rendered with the prefix "认证失败", i.e. "authentication failed").
    Auth(String),
    /// A failed query, with the fields needed to report it.
    Query {
        /// Error code, e.g. a SQLSTATE such as `"23505"`.
        code: String,
        /// Primary error message.
        message: String,
        /// Additional detail text.
        detail: String,
        /// The SQL text associated with the error.
        sql: String,
        /// Position associated with the error.
        ///
        /// NOTE(review): `Display` labels this value `line`, but the PostgreSQL
        /// protocol's `P` field is a 1-based character offset into the query
        /// string; confirm which one is stored here. As a `u16` it cannot
        /// represent offsets above 65535.
        position: u16,
    },
    /// Connection-pool error (rendered with the prefix "连接池错误", i.e. "connection pool error").
    Pool(String),
    /// Configuration error (rendered with the prefix "Config error").
    Config(String),
}
21
22impl fmt::Display for PgsqlError {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        match self {
25            PgsqlError::Connection(msg) => write!(f, "连接错误: {}", msg),
26            PgsqlError::Io(err) => write!(f, "IO错误: {}", err),
27            PgsqlError::Timeout(msg) => write!(f, "超时: {}", msg),
28            PgsqlError::Protocol(msg) => write!(f, "协议错误: {}", msg),
29            PgsqlError::Auth(msg) => write!(f, "认证失败: {}", msg),
30            PgsqlError::Query {
31                code,
32                message,
33                detail,
34                sql,
35                position,
36            } => {
37                write!(
38                    f,
39                    "Code: {} ErrorMsg[line:{}]: {} detail: {} SQL: {}",
40                    code, position, message, detail, sql
41                )
42            }
43            PgsqlError::Pool(msg) => write!(f, "连接池错误: {}", msg),
44            PgsqlError::Config(msg) => write!(f, "Config error: {}", msg),
45        }
46    }
47}
48
49impl std::error::Error for PgsqlError {
50    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
51        match self {
52            PgsqlError::Io(err) => Some(err),
53            _ => None,
54        }
55    }
56}
57
58impl From<io::Error> for PgsqlError {
59    fn from(err: io::Error) -> Self {
60        PgsqlError::Io(err)
61    }
62}
63
64impl From<String> for PgsqlError {
65    fn from(s: String) -> Self {
66        PgsqlError::Protocol(s)
67    }
68}
69
70impl From<&str> for PgsqlError {
71    fn from(s: &str) -> Self {
72        PgsqlError::Protocol(s.to_string())
73    }
74}
75
76impl From<PgsqlError> for String {
77    fn from(err: PgsqlError) -> Self {
78        err.to_string()
79    }
80}
81
#[cfg(test)]
mod tests {
    // Unit tests pinning the `Display`, `source`, `From` and `Debug` behaviour of `PgsqlError`.
    use super::PgsqlError;
    use std::error::Error;
    use std::io;

    /// A representative `Query` error shared by several tests.
    fn sample_query_error() -> PgsqlError {
        PgsqlError::Query {
            code: "23505".to_string(),
            message: "duplicate key".to_string(),
            detail: "key exists".to_string(),
            sql: "INSERT INTO t VALUES (1)".to_string(),
            position: 12,
        }
    }

    #[test]
    fn display_formats_all_variants() {
        let connection = PgsqlError::Connection("network down".to_string());
        assert_eq!(connection.to_string(), "连接错误: network down");

        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "pipe closed"));
        assert_eq!(io_err.to_string(), "IO错误: pipe closed");

        let timeout = PgsqlError::Timeout("request timeout".to_string());
        assert_eq!(timeout.to_string(), "超时: request timeout");

        let protocol = PgsqlError::Protocol("invalid packet".to_string());
        assert_eq!(protocol.to_string(), "协议错误: invalid packet");

        let auth = PgsqlError::Auth("wrong password".to_string());
        assert_eq!(auth.to_string(), "认证失败: wrong password");

        let query = sample_query_error();
        assert_eq!(
            query.to_string(),
            "Code: 23505 ErrorMsg[line:12]: duplicate key detail: key exists SQL: INSERT INTO t VALUES (1)"
        );

        let pool = PgsqlError::Pool("pool exhausted".to_string());
        assert_eq!(pool.to_string(), "连接池错误: pool exhausted");

        let config_err = PgsqlError::Config("bad url".to_string());
        assert_eq!(config_err.to_string(), "Config error: bad url");
    }

    #[test]
    fn source_returns_some_for_io_variant() {
        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::TimedOut, "socket timeout"));

        let source = io_err.source().expect("Io variant should expose source");
        assert_eq!(source.to_string(), "socket timeout");
    }

    #[test]
    fn source_returns_none_for_non_io_variants() {
        // Every non-`Io` variant, including `Config`, must report no source.
        let errs = [
            PgsqlError::Connection("c".to_string()),
            PgsqlError::Timeout("t".to_string()),
            PgsqlError::Protocol("p".to_string()),
            PgsqlError::Auth("a".to_string()),
            sample_query_error(),
            PgsqlError::Pool("pool".to_string()),
            PgsqlError::Config("cfg".to_string()),
        ];

        for err in &errs {
            assert!(err.source().is_none(), "unexpected source for {err:?}");
        }
    }

    #[test]
    fn from_io_error_creates_io_variant() {
        let err = io::Error::new(io::ErrorKind::NotFound, "missing file");
        let pg_err: PgsqlError = err.into();

        match pg_err {
            PgsqlError::Io(inner) => {
                assert_eq!(inner.kind(), io::ErrorKind::NotFound);
                assert_eq!(inner.to_string(), "missing file");
            }
            other => panic!("expected Io variant, got {other:?}"),
        }
    }

    #[test]
    fn from_string_creates_protocol_variant() {
        let pg_err: PgsqlError = String::from("bad response").into();

        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "bad response"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_str_creates_protocol_variant() {
        let pg_err: PgsqlError = "decode error".into();

        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "decode error"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_pgsql_error_to_string_uses_display_output() {
        let err = PgsqlError::Auth("login failed".to_string());
        let value: String = err.into();

        assert_eq!(value, "认证失败: login failed");
    }

    #[test]
    fn debug_derive_formats_variant_details() {
        let err = sample_query_error();
        let debug = format!("{:?}", err);

        assert!(debug.contains("Query"));
        assert!(debug.contains("23505"));
        assert!(debug.contains("duplicate key"));
        assert!(debug.contains("position: 12"));
    }
}