use std::fmt;
use std::io;
/// Error type for the PostgreSQL client code.
///
/// Most variants carry a free-form message. `Io` wraps an [`io::Error`], and
/// `Query` carries the structured fields of a server-reported query error.
#[derive(Debug)]
pub enum PgsqlError {
    /// Connection-related error with a descriptive message.
    Connection(String),
    /// Wrapped I/O error; this is the only variant that reports a `source()`.
    Io(io::Error),
    /// Timeout with a descriptive message.
    Timeout(String),
    /// Protocol-level error. `From<String>` and `From<&str>` also produce this variant.
    Protocol(String),
    /// Authentication failure with a descriptive message.
    Auth(String),
    /// Structured error reported for a SQL statement.
    Query {
        /// Error code; presumably the SQLSTATE (e.g. `23505`). TODO confirm.
        code: String,
        /// Primary error message.
        message: String,
        /// Additional detail text.
        detail: String,
        /// The SQL text that triggered the error.
        sql: String,
        /// Position associated with the error.
        ///
        /// NOTE(review): if this is the protocol's `P` field, it is a 1-based
        /// *character* index into `sql`, not a line number. In that case `u16`
        /// cannot represent offsets beyond 65535. Confirm against the parser.
        position: u16,
    },
    /// Connection-pool error with a descriptive message.
    Pool(String),
    /// Configuration error with a descriptive message.
    Config(String),
}
impl fmt::Display for PgsqlError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PgsqlError::Connection(msg) => write!(f, "连接错误: {}", msg),
PgsqlError::Io(err) => write!(f, "IO错误: {}", err),
PgsqlError::Timeout(msg) => write!(f, "超时: {}", msg),
PgsqlError::Protocol(msg) => write!(f, "协议错误: {}", msg),
PgsqlError::Auth(msg) => write!(f, "认证失败: {}", msg),
PgsqlError::Query {
code,
message,
detail,
sql,
position,
} => {
write!(
f,
"Code: {} ErrorMsg[line:{}]: {} detail: {} SQL: {}",
code, position, message, detail, sql
)
}
PgsqlError::Pool(msg) => write!(f, "连接池错误: {}", msg),
PgsqlError::Config(msg) => write!(f, "Config error: {}", msg),
}
}
}
impl std::error::Error for PgsqlError {
    /// Returns the underlying cause when this error wraps another error.
    ///
    /// Only [`PgsqlError::Io`] wraps another error. Every other variant holds
    /// plain data and reports no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PgsqlError::Io(err) => Some(err),
            // Listed explicitly instead of using `_`, so that adding a variant
            // that wraps another error becomes a compile error here rather
            // than silently reporting `None`.
            PgsqlError::Connection(_)
            | PgsqlError::Timeout(_)
            | PgsqlError::Protocol(_)
            | PgsqlError::Auth(_)
            | PgsqlError::Query { .. }
            | PgsqlError::Pool(_)
            | PgsqlError::Config(_) => None,
        }
    }
}
impl From<io::Error> for PgsqlError {
fn from(err: io::Error) -> Self {
PgsqlError::Io(err)
}
}
impl From<String> for PgsqlError {
fn from(s: String) -> Self {
PgsqlError::Protocol(s)
}
}
impl From<&str> for PgsqlError {
fn from(s: &str) -> Self {
PgsqlError::Protocol(s.to_string())
}
}
impl From<PgsqlError> for String {
fn from(err: PgsqlError) -> Self {
err.to_string()
}
}
// Unit tests for `PgsqlError`: Display text, `source()`, and the `From`
// conversions.
#[cfg(test)]
mod tests {
    use super::PgsqlError;
    use std::error::Error;
    use std::io;

    /// Builds a representative `Query` error shared by several tests.
    fn sample_query_error() -> PgsqlError {
        PgsqlError::Query {
            code: "23505".to_string(),
            message: "duplicate key".to_string(),
            detail: "key exists".to_string(),
            sql: "INSERT INTO t VALUES (1)".to_string(),
            position: 12,
        }
    }

    #[test]
    fn display_formats_all_variants() {
        let connection = PgsqlError::Connection("network down".to_string());
        assert_eq!(connection.to_string(), "连接错误: network down");
        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "pipe closed"));
        assert_eq!(io_err.to_string(), "IO错误: pipe closed");
        let timeout = PgsqlError::Timeout("request timeout".to_string());
        assert_eq!(timeout.to_string(), "超时: request timeout");
        let protocol = PgsqlError::Protocol("invalid packet".to_string());
        assert_eq!(protocol.to_string(), "协议错误: invalid packet");
        let auth = PgsqlError::Auth("wrong password".to_string());
        assert_eq!(auth.to_string(), "认证失败: wrong password");
        let query = sample_query_error();
        assert_eq!(
            query.to_string(),
            "Code: 23505 ErrorMsg[line:12]: duplicate key detail: key exists SQL: INSERT INTO t VALUES (1)"
        );
        let pool = PgsqlError::Pool("pool exhausted".to_string());
        assert_eq!(pool.to_string(), "连接池错误: pool exhausted");
        let config_err = PgsqlError::Config("bad url".to_string());
        assert_eq!(config_err.to_string(), "Config error: bad url");
    }

    #[test]
    fn source_returns_some_for_io_variant() {
        let io_err = PgsqlError::Io(io::Error::new(io::ErrorKind::TimedOut, "socket timeout"));
        let source = io_err.source().expect("Io variant should expose source");
        assert_eq!(source.to_string(), "socket timeout");
    }

    #[test]
    fn source_returns_none_for_non_io_variants() {
        // Every variant other than `Io`, including `Config`, must report no source.
        let errs = [
            PgsqlError::Connection("c".to_string()),
            PgsqlError::Timeout("t".to_string()),
            PgsqlError::Protocol("p".to_string()),
            PgsqlError::Auth("a".to_string()),
            sample_query_error(),
            PgsqlError::Pool("pool".to_string()),
            PgsqlError::Config("cfg".to_string()),
        ];
        for err in &errs {
            assert!(err.source().is_none(), "unexpected source for {err:?}");
        }
    }

    #[test]
    fn from_io_error_creates_io_variant() {
        let err = io::Error::new(io::ErrorKind::NotFound, "missing file");
        let pg_err: PgsqlError = err.into();
        match pg_err {
            PgsqlError::Io(inner) => {
                assert_eq!(inner.kind(), io::ErrorKind::NotFound);
                assert_eq!(inner.to_string(), "missing file");
            }
            other => panic!("expected Io variant, got {other:?}"),
        }
    }

    #[test]
    fn from_string_creates_protocol_variant() {
        let pg_err: PgsqlError = String::from("bad response").into();
        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "bad response"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_str_creates_protocol_variant() {
        let pg_err: PgsqlError = "decode error".into();
        match pg_err {
            PgsqlError::Protocol(msg) => assert_eq!(msg, "decode error"),
            other => panic!("expected Protocol variant, got {other:?}"),
        }
    }

    #[test]
    fn from_pgsql_error_to_string_uses_display_output() {
        let err = PgsqlError::Auth("login failed".to_string());
        let value: String = err.into();
        assert_eq!(value, "认证失败: login failed");
    }

    #[test]
    fn debug_derive_formats_variant_details() {
        let err = sample_query_error();
        let debug = format!("{:?}", err);
        assert!(debug.contains("Query"));
        assert!(debug.contains("23505"));
        assert!(debug.contains("duplicate key"));
        assert!(debug.contains("position: 12"));
    }
}