use crate::pool::PoolMessage;
/// Unified error type wrapping the database, I/O, channel and TLS errors
/// listed below.
///
/// Every variant stores the underlying error value as-is, so the original
/// error remains available to callers that match on the variant.
#[derive(Debug)]
pub enum PoolError {
    /// An error reported by `tokio_postgres`.
    Database(tokio_postgres::Error),
    /// A standard library I/O error.
    Io(std::io::Error),
    /// A `tokio` oneshot receiver failed to obtain a value.
    Recv(tokio::sync::oneshot::error::RecvError),
    /// A `PoolMessage` could not be sent over a `tokio` mpsc channel.
    ///
    /// The send error carries the unsent message; it is boxed, presumably to
    /// keep `PoolError` itself small (TODO confirm against `PoolMessage`'s size).
    Send(Box<tokio::sync::mpsc::error::SendError<PoolMessage>>),
    /// An error reported by `native_tls` (re-exported by `tokio_native_tls`).
    Tls(tokio_native_tls::native_tls::Error),
}
impl std::fmt::Display for PoolError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
PoolError::Database(err) => std::fmt::Display::fmt(err, f),
PoolError::Io(err) => std::fmt::Display::fmt(err, f),
PoolError::Tls(err) => std::fmt::Display::fmt(err, f),
PoolError::Recv(err) => std::fmt::Display::fmt(err, f),
PoolError::Send(err) => std::fmt::Display::fmt(err, f),
}
}
}
impl std::error::Error for PoolError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
PoolError::Database(err) => Some(err),
PoolError::Io(err) => Some(err),
PoolError::Tls(err) => Some(err),
PoolError::Recv(err) => Some(err),
PoolError::Send(err) => Some(err),
}
}
}
impl From<tokio_postgres::Error> for PoolError {
fn from(kind: tokio_postgres::Error) -> Self {
PoolError::Database(kind)
}
}
impl From<std::io::Error> for PoolError {
fn from(kind: std::io::Error) -> Self {
PoolError::Io(kind)
}
}
impl From<tokio::sync::oneshot::error::RecvError> for PoolError {
fn from(kind: tokio::sync::oneshot::error::RecvError) -> Self {
PoolError::Recv(kind)
}
}
impl From<tokio::sync::mpsc::error::SendError<PoolMessage>> for PoolError {
fn from(kind: tokio::sync::mpsc::error::SendError<PoolMessage>) -> Self {
PoolError::Send(Box::new(kind))
}
}
impl From<tokio_native_tls::native_tls::Error> for PoolError {
fn from(kind: tokio_native_tls::native_tls::Error) -> Self {
PoolError::Tls(kind)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;

    /// Checks that `error` exposes a source and that both the error and its
    /// source render as `expected`.
    fn assert_error_and_source(error: &PoolError, expected: &str) {
        let source = error.source().unwrap();
        assert_eq!(error.to_string(), expected);
        assert_eq!(source.to_string(), expected);
    }

    /// A connection attempt with the string `"test.invalid"` yields an error
    /// that converts into `PoolError` and keeps its message.
    #[tokio::test]
    async fn test_pool_error_database() {
        let connect_result = tokio_postgres::connect("test.invalid", tokio_postgres::NoTls).await;
        let pg_error = connect_result.err().unwrap();

        let error = PoolError::from(pg_error);

        assert_error_and_source(&error, "invalid connection string: unexpected EOF");
    }

    /// An I/O error converts into `PoolError` and keeps its message.
    #[test]
    fn test_pool_error_io() {
        let message = "io error";

        let error = PoolError::from(std::io::Error::other(message));

        assert_error_and_source(&error, message);
    }

    /// Awaiting a oneshot receiver whose sender was dropped yields an error
    /// that converts into `PoolError` and keeps its message.
    #[tokio::test]
    async fn test_pool_error_recv() {
        let (sender, receiver) = tokio::sync::oneshot::channel::<i32>();
        drop(sender);
        let recv_error = receiver.await.unwrap_err();

        let error = PoolError::from(recv_error);

        assert_error_and_source(&error, "channel closed");
    }

    /// Sending on an mpsc channel whose receiver was dropped yields an error
    /// that converts into `PoolError` and keeps its message.
    #[tokio::test]
    async fn test_pool_error_send() {
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        drop(receiver);
        let outgoing = PoolMessage::ReturnClient { client: None };
        let send_error = sender.send(outgoing).await.unwrap_err();

        let error = PoolError::from(send_error);

        assert_error_and_source(&error, "channel closed");
    }

    /// Building a TLS identity from junk bytes yields an error that converts
    /// into `PoolError` and keeps its message.
    #[test]
    fn test_pool_error_tls() {
        let identity_result = tokio_native_tls::native_tls::Identity::from_pkcs8(&[0xff], &[0xff]);
        let tls_error = identity_result.err().unwrap();

        let error = PoolError::from(tls_error);

        assert_error_and_source(&error, "expected PKCS#8 PEM");
    }
}