bestool_postgres/pool/
manager.rs

1use miette::Diagnostic;
2use mobc::{Manager, async_trait};
3use thiserror::Error;
4use tokio_postgres::{CancelToken, Client, Config, NoTls, error::DbError};
5
6#[derive(Debug, Clone)]
7pub struct PgConnectionManager {
8	config: Config,
9	tls: bool,
10}
11
12impl PgConnectionManager {
13	pub fn new(config: Config, tls: bool) -> Self {
14		Self { config, tls }
15	}
16
17	pub async fn cancel(&self, token: &CancelToken) -> Result<(), PgError> {
18		if self.tls {
19			let tls_connector = super::tls::make_tls_connector()?;
20			token.cancel_query(tls_connector).await?;
21		} else {
22			token.cancel_query(NoTls).await?;
23		}
24
25		Ok(())
26	}
27}
28
29#[derive(Error, Debug, Diagnostic)]
30pub enum PgError {
31	#[error("tls: {0}")]
32	Tls(#[from] rustls::Error),
33	#[error("postgres: {0}")]
34	Pg(#[from] tokio_postgres::Error),
35}
36
37impl PgError {
38	pub fn as_db_error(&self) -> Option<&DbError> {
39		match self {
40			PgError::Pg(e) => e.as_db_error(),
41			_ => None,
42		}
43	}
44}
45
46#[async_trait]
47impl Manager for PgConnectionManager {
48	type Connection = Client;
49	type Error = PgError;
50
51	async fn connect(&self) -> Result<Self::Connection, Self::Error> {
52		if self.tls {
53			let tls_connector = super::tls::make_tls_connector()?;
54			let (client, conn) = self.config.connect(tls_connector).await?;
55			mobc::spawn(conn);
56			Ok(client)
57		} else {
58			let (client, conn) = self.config.connect(NoTls).await?;
59			mobc::spawn(conn);
60			Ok(client)
61		}
62	}
63
64	async fn check(&self, conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
65		conn.simple_query("").await?;
66		Ok(conn)
67	}
68}