Skip to main content

br_pgsql/
error.rs

1use std::fmt;
2use std::io;
3
/// Error type covering every failure category this module distinguishes.
///
/// Most variants wrap a free-form `String` message; `Io` wraps the underlying
/// [`io::Error`] and `Query` carries a structured set of fields. The `Display`
/// impl in this file renders each variant with its own label.
#[derive(Debug)]
pub enum PgsqlError {
    /// Connection error carrying a descriptive message.
    Connection(String),
    /// Wrapped I/O error; the only variant that reports an error `source`.
    Io(io::Error),
    /// Timeout carrying a descriptive message.
    Timeout(String),
    /// Protocol error carrying a descriptive message. Bare `String` / `&str`
    /// values convert into this variant through the `From` impls in this file.
    Protocol(String),
    /// Authentication failure carrying a descriptive message.
    Auth(String),
    /// Structured error for a failed SQL statement.
    Query {
        /// Error code (presumably the server's SQLSTATE value — confirm with the producer).
        code: String,
        /// Primary error message.
        message: String,
        /// Additional detail text.
        detail: String,
        /// The SQL text associated with the error.
        sql: String,
        /// Position value; `Display` prints it under the `line:` label.
        position: u16,
    },
    /// Connection-pool error carrying a descriptive message.
    Pool(String),
}
20
21impl fmt::Display for PgsqlError {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self {
24            PgsqlError::Connection(msg) => write!(f, "连接错误: {}", msg),
25            PgsqlError::Io(err) => write!(f, "IO错误: {}", err),
26            PgsqlError::Timeout(msg) => write!(f, "超时: {}", msg),
27            PgsqlError::Protocol(msg) => write!(f, "协议错误: {}", msg),
28            PgsqlError::Auth(msg) => write!(f, "认证失败: {}", msg),
29            PgsqlError::Query {
30                code,
31                message,
32                detail,
33                sql,
34                position,
35            } => {
36                write!(
37                    f,
38                    "Code: {} ErrorMsg[line:{}]: {} detail: {} SQL: {}",
39                    code, position, message, detail, sql
40                )
41            }
42            PgsqlError::Pool(msg) => write!(f, "连接池错误: {}", msg),
43        }
44    }
45}
46
47impl std::error::Error for PgsqlError {
48    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
49        match self {
50            PgsqlError::Io(err) => Some(err),
51            _ => None,
52        }
53    }
54}
55
56impl From<io::Error> for PgsqlError {
57    fn from(err: io::Error) -> Self {
58        PgsqlError::Io(err)
59    }
60}
61
62impl From<String> for PgsqlError {
63    fn from(s: String) -> Self {
64        PgsqlError::Protocol(s)
65    }
66}
67
68impl From<&str> for PgsqlError {
69    fn from(s: &str) -> Self {
70        PgsqlError::Protocol(s.to_string())
71    }
72}
73
74impl From<PgsqlError> for String {
75    fn from(err: PgsqlError) -> Self {
76        err.to_string()
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::PgsqlError;
    use std::error::Error;
    use std::io;

    /// Builds the `Query` fixture shared by the display, source and debug tests.
    fn sample_query_error() -> PgsqlError {
        PgsqlError::Query {
            code: String::from("23505"),
            message: String::from("duplicate key"),
            detail: String::from("key exists"),
            sql: String::from("INSERT INTO t VALUES (1)"),
            position: 12,
        }
    }

    // One expected `Display` string per variant.
    #[test]
    fn display_formats_all_variants() {
        let cases = [
            (
                PgsqlError::Connection(String::from("network down")),
                "连接错误: network down",
            ),
            (
                PgsqlError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "pipe closed")),
                "IO错误: pipe closed",
            ),
            (
                PgsqlError::Timeout(String::from("request timeout")),
                "超时: request timeout",
            ),
            (
                PgsqlError::Protocol(String::from("invalid packet")),
                "协议错误: invalid packet",
            ),
            (
                PgsqlError::Auth(String::from("wrong password")),
                "认证失败: wrong password",
            ),
            (
                sample_query_error(),
                "Code: 23505 ErrorMsg[line:12]: duplicate key detail: key exists SQL: INSERT INTO t VALUES (1)",
            ),
            (
                PgsqlError::Pool(String::from("pool exhausted")),
                "连接池错误: pool exhausted",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    // The `Io` variant must expose the wrapped error through `source()`.
    #[test]
    fn source_returns_some_for_io_variant() {
        let wrapped = PgsqlError::Io(io::Error::new(io::ErrorKind::TimedOut, "socket timeout"));

        let inner = wrapped.source().expect("Io variant should expose source");
        assert_eq!(inner.to_string(), "socket timeout");
    }

    // Every variant other than `Io` reports no underlying source.
    #[test]
    fn source_returns_none_for_non_io_variants() {
        let without_source = [
            PgsqlError::Connection(String::from("c")),
            PgsqlError::Timeout(String::from("t")),
            PgsqlError::Protocol(String::from("p")),
            PgsqlError::Auth(String::from("a")),
            sample_query_error(),
            PgsqlError::Pool(String::from("pool")),
        ];

        for err in without_source {
            assert!(err.source().is_none());
        }
    }

    #[test]
    fn from_io_error_creates_io_variant() {
        let converted = PgsqlError::from(io::Error::new(io::ErrorKind::NotFound, "missing file"));

        match converted {
            PgsqlError::Io(inner) => {
                assert_eq!(inner.kind(), io::ErrorKind::NotFound);
                assert_eq!(inner.to_string(), "missing file");
            }
            other => panic!("expected Io variant, got {other:?}"),
        }
    }

    #[test]
    fn from_string_creates_protocol_variant() {
        let converted = PgsqlError::from(String::from("bad response"));

        match converted {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "bad response"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_str_creates_protocol_variant() {
        let converted = PgsqlError::from("decode error");

        match converted {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "decode error"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    // Converting into `String` must yield exactly the `Display` text.
    #[test]
    fn from_pgsql_error_to_string_uses_display_output() {
        let rendered = String::from(PgsqlError::Auth(String::from("login failed")));

        assert_eq!(rendered, "认证失败: login failed");
    }

    // The derived `Debug` output names the variant and shows its field values.
    #[test]
    fn debug_derive_formats_variant_details() {
        let rendered = format!("{:?}", sample_query_error());

        for needle in ["Query", "23505", "duplicate key", "position: 12"] {
            assert!(rendered.contains(needle));
        }
    }
}